class DirectoryServiceBackend(BaseBackend):
    """Implementation of DirectoryService APIs."""

    def __init__(self, region_name=None):
        self.region_name = region_name
        self.directories = {}  # Key is the directory ID.
        self.tagger = TaggingService()

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """List of dicts representing default VPC endpoints for this service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "ds"
        )

    @staticmethod
    def _verify_subnets(region, vpc_settings):
        """Verify subnets are valid, else raise an exception.

        If settings are valid, add AvailabilityZones to vpc_settings.

        Raises:
            InvalidParameterException: SubnetIds does not hold exactly two IDs
                or an ID is unknown to the EC2 backend.
            ClientException: the subnets share an availability zone, or the
                VpcId is unknown.
        """
        if len(vpc_settings["SubnetIds"]) != 2:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones."
            )

        # Subnet IDs are checked before the VPC ID.  The Subnet IDs must
        # be valid and in different availability zones.
        try:
            subnets = ec2_backends[region].get_all_subnets(
                subnet_ids=vpc_settings["SubnetIds"]
            )
        except InvalidSubnetIdError as exc:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones."
            ) from exc

        zones = [subnet.availability_zone for subnet in subnets]
        # Requiring two distinct zones is equivalent to zones[0] != zones[1]
        # when two subnets come back, and avoids an IndexError if the lookup
        # returns fewer (presumably when the same ID is listed twice).
        if len(set(zones)) != 2:
            raise ClientException(
                "Invalid subnet ID(s). The two subnets must be in "
                "different Availability Zones."
            )

        vpcs = ec2_backends[region].describe_vpcs()
        if vpc_settings["VpcId"] not in [x.id for x in vpcs]:
            raise ClientException("Invalid VPC ID.")

        vpc_settings["AvailabilityZones"] = zones

    def _count_directories(self, *directory_types):
        """Return how many stored directories have a type in directory_types."""
        return sum(
            1
            for directory in self.directories.values()
            if directory.directory_type in directory_types
        )

    def connect_directory(
        self,
        region,
        name,
        short_name,
        password,
        description,
        size,
        connect_settings,
        tags,
    ):  # pylint: disable=too-many-arguments
        """Create a fake AD Connector and return its directory ID.

        ``tags`` may be None, which is treated as no tags.

        Raises:
            DirectoryLimitExceededException: the AD Connector limit has been
                reached, or too many tags were supplied.
            ValidationException: the tags are malformed.
        """
        # Compare against AD Connectors only, matching get_directory_limits().
        if (
            self._count_directories("ADConnector")
            >= Directory.CONNECTED_DIRECTORIES_LIMIT
        ):
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CONNECTED_DIRECTORIES_LIMIT} directories may be created"
            )

        validate_args(
            [
                ("password", password),
                ("size", size),
                ("name", name),
                ("description", description),
                ("shortName", short_name),
                (
                    "connectSettings.vpcSettings.subnetIds",
                    connect_settings["SubnetIds"],
                ),
                (
                    "connectSettings.customerUserName",
                    connect_settings["CustomerUserName"],
                ),
                ("connectSettings.customerDnsIps", connect_settings["CustomerDnsIps"]),
            ]
        )
        # ConnectSettings and VpcSettings both have a VpcId and Subnets.
        self._verify_subnets(region, connect_settings)

        # Normalize once so len() below cannot fail on None.
        tags = tags or []
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            region,
            name,
            password,
            "ADConnector",
            size=size,
            connect_settings=connect_settings,
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags)
        return directory.directory_id

    def create_directory(
        self, region, name, short_name, password, description, size, vpc_settings, tags
    ):  # pylint: disable=too-many-arguments
        """Create a fake Simple AD directory and return its directory ID.

        ``tags`` may be None, which is treated as no tags.

        Raises:
            DirectoryLimitExceededException: the Simple AD limit has been
                reached, or too many tags were supplied.
            InvalidParameterException: vpc_settings is missing.
            ValidationException: the tags are malformed.
        """
        # Compare against Simple AD directories only, matching
        # get_directory_limits().
        if self._count_directories("SimpleAD") >= Directory.CLOUDONLY_DIRECTORIES_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CLOUDONLY_DIRECTORIES_LIMIT} directories may be created"
            )

        # botocore doesn't look for missing vpc_settings, but boto3 does.
        if not vpc_settings:
            raise InvalidParameterException("VpcSettings must be specified.")

        validate_args(
            [
                ("password", password),
                ("size", size),
                ("name", name),
                ("description", description),
                ("shortName", short_name),
                ("vpcSettings.subnetIds", vpc_settings["SubnetIds"]),
            ]
        )
        self._verify_subnets(region, vpc_settings)

        # Normalize once so len() below cannot fail on None.
        tags = tags or []
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            region,
            name,
            password,
            "SimpleAD",
            size=size,
            vpc_settings=vpc_settings,
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags)
        return directory.directory_id

    def _validate_directory_id(self, directory_id):
        """Raise an exception if the directory id is invalid or unknown."""
        # Validation of ID takes precedence over a check for its existence.
        validate_args([("directoryId", directory_id)])
        if directory_id not in self.directories:
            raise EntityDoesNotExistException(
                f"Directory {directory_id} does not exist"
            )

    def create_alias(self, directory_id, alias):
        """Create and assign an alias to a directory.

        Raises:
            InvalidParameterException: the directory already has an alias.
            EntityAlreadyExistsException: another directory uses this alias.
        """
        self._validate_directory_id(directory_id)

        # The default alias name is the same as the directory name.  Check
        # whether this directory was already given an alias.
        directory = self.directories[directory_id]
        if directory.alias != directory_id:
            raise InvalidParameterException(
                "The directory in the request already has an alias. That "
                "alias must be deleted before a new alias can be created."
            )

        # Is the alias already in use?
        if alias in [x.alias for x in self.directories.values()]:
            raise EntityAlreadyExistsException(f"Alias '{alias}' already exists.")
        validate_args([("alias", alias)])

        directory.update_alias(alias)
        return {"DirectoryId": directory_id, "Alias": alias}

    def create_microsoft_ad(
        self,
        region,
        name,
        short_name,
        password,
        description,
        vpc_settings,
        edition,
        tags,
    ):  # pylint: disable=too-many-arguments
        """Create a fake Microsoft AD directory and return its directory ID.

        ``tags`` may be None, which is treated as no tags.

        Raises:
            DirectoryLimitExceededException: the Microsoft AD limit has been
                reached, or too many tags were supplied.
            ValidationException: the tags are malformed.
        """
        # Compare against (Shared)MicrosoftAD directories only, matching
        # get_directory_limits().
        if (
            self._count_directories("MicrosoftAD", "SharedMicrosoftAD")
            >= Directory.CLOUDONLY_MICROSOFT_AD_LIMIT
        ):
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CLOUDONLY_MICROSOFT_AD_LIMIT} directories may be created"
            )

        # boto3 looks for missing vpc_settings for create_microsoft_ad().
        validate_args(
            [
                ("password", password),
                ("edition", edition),
                ("name", name),
                ("description", description),
                ("shortName", short_name),
                ("vpcSettings.subnetIds", vpc_settings["SubnetIds"]),
            ]
        )
        self._verify_subnets(region, vpc_settings)

        # Normalize once so len() below cannot fail on None.
        tags = tags or []
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            region,
            name,
            password,
            "MicrosoftAD",
            vpc_settings=vpc_settings,
            short_name=short_name,
            description=description,
            edition=edition,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags)
        return directory.directory_id

    def delete_directory(self, directory_id):
        """Delete directory with the matching ID and its tags; return the ID."""
        self._validate_directory_id(directory_id)
        self.tagger.delete_all_tags_for_resource(directory_id)
        self.directories.pop(directory_id)
        return directory_id

    def disable_sso(self, directory_id, username=None, password=None):
        """Disable single-sign on for a directory."""
        self._validate_directory_id(directory_id)
        validate_args([("ssoPassword", password), ("userName", username)])
        directory = self.directories[directory_id]
        directory.enable_sso(False)

    def enable_sso(self, directory_id, username=None, password=None):
        """Enable single-sign on for a directory.

        Raises:
            ClientException: the directory has no alias yet (its alias still
                equals its ID).
        """
        self._validate_directory_id(directory_id)
        validate_args([("ssoPassword", password), ("userName", username)])

        directory = self.directories[directory_id]
        if directory.alias == directory_id:
            raise ClientException(
                f"An alias is required before enabling SSO. DomainId={directory_id}"
            )
        directory.enable_sso(True)

    @paginate(pagination_model=PAGINATION_MODEL)
    def describe_directories(
        self, directory_ids=None, next_token=None, limit=0
    ):  # pylint: disable=unused-argument
        """Return info on all directories or directories with matching IDs."""
        for directory_id in directory_ids or self.directories:
            self._validate_directory_id(directory_id)

        directories = list(self.directories.values())
        if directory_ids:
            directories = [x for x in directories if x.directory_id in directory_ids]
        return sorted(directories, key=lambda x: x.launch_time)

    def get_directory_limits(self):
        """Return hard-coded limits alongside per-type directory counts."""
        simple_ad_count = self._count_directories("SimpleAD")
        microsoft_ad_count = self._count_directories(
            "MicrosoftAD", "SharedMicrosoftAD"
        )
        connected_count = self._count_directories("ADConnector")
        return {
            "CloudOnlyDirectoriesLimit": Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyDirectoriesCurrentCount": simple_ad_count,
            "CloudOnlyDirectoriesLimitReached": simple_ad_count
            >= Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyMicrosoftADLimit": Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "CloudOnlyMicrosoftADCurrentCount": microsoft_ad_count,
            "CloudOnlyMicrosoftADLimitReached": microsoft_ad_count
            >= Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "ConnectedDirectoriesLimit": Directory.CONNECTED_DIRECTORIES_LIMIT,
            "ConnectedDirectoriesCurrentCount": connected_count,
            "ConnectedDirectoriesLimitReached": connected_count
            >= Directory.CONNECTED_DIRECTORIES_LIMIT,
        }

    def add_tags_to_resource(self, resource_id, tags):
        """Add or overwrite one or more tags for specified directory."""
        self._validate_directory_id(resource_id)
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise TagLimitExceededException("Tag limit exceeded")
        self.tagger.tag_resource(resource_id, tags)

    def remove_tags_from_resource(self, resource_id, tag_keys):
        """Removes tags from a directory."""
        self._validate_directory_id(resource_id)
        self.tagger.untag_resource_using_names(resource_id, tag_keys)

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags_for_resource(
        self,
        resource_id,
        next_token=None,
        limit=None,
    ):  # pylint: disable=unused-argument
        """List all tags on a directory."""
        self._validate_directory_id(resource_id)
        return self.tagger.list_tags_for_resource(resource_id).get("Tags")
class Route53ResolverBackend(BaseBackend):
    """Implementation of Route53Resolver APIs."""

    def __init__(self, region_name, account_id):
        super().__init__(region_name, account_id)
        self.resolver_endpoints = {}  # Key is self-generated ID (endpoint_id)
        self.resolver_rules = {}  # Key is self-generated ID (rule_id)
        self.resolver_rule_associations = {}  # Key is resolver_rule_association_id
        # Tags for endpoints and rules are stored under the resource's ARN.
        self.tagger = TaggingService()

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """List of dicts representing default VPC endpoints for this service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "route53resolver"
        )

    @staticmethod
    def _name_sort_key(entity):
        """Sort key ordering entities by name, placing None names last.

        For string names this orders exactly like sorting on ``name`` alone,
        but it cannot raise TypeError when a name is None.
        """
        return (entity.name is None, entity.name or "")

    def associate_resolver_rule(self, region, resolver_rule_id, name, vpc_id):
        """Associate a resolver rule with a VPC and return the association."""
        validate_args(
            [("resolverRuleId", resolver_rule_id), ("name", name), ("vPCId", vpc_id)]
        )

        associations = [
            x for x in self.resolver_rule_associations.values() if x.region == region
        ]
        # NOTE(review): ">" admits one more association than the constant's
        # name suggests; confirm the intended limit semantics.
        if len(associations) > ResolverRuleAssociation.MAX_RULE_ASSOCIATIONS_PER_REGION:
            # This error message was not verified to be the same for AWS.
            raise LimitExceededException(
                f"Account '{get_account_id()}' has exceeded 'max-rule-association'"
            )

        if resolver_rule_id not in self.resolver_rules:
            raise ResourceNotFoundException(
                f"Resolver rule with ID '{resolver_rule_id}' does not exist."
            )

        vpcs = ec2_backends[region].describe_vpcs()
        if vpc_id not in [x.id for x in vpcs]:
            raise InvalidParameterException(f"The vpc ID '{vpc_id}' does not exist")

        # Can't duplicate resolver rule, vpc id associations.
        for association in self.resolver_rule_associations.values():
            if (
                resolver_rule_id == association.resolver_rule_id
                and vpc_id == association.vpc_id
            ):
                raise InvalidRequestException(
                    f"Cannot associate rules with same domain name with same "
                    f"VPC. Conflict with resolver rule '{resolver_rule_id}'"
                )

        rule_association_id = f"rslvr-rrassoc-{get_random_hex(17)}"
        rule_association = ResolverRuleAssociation(
            region, rule_association_id, resolver_rule_id, vpc_id, name
        )
        self.resolver_rule_associations[rule_association_id] = rule_association
        return rule_association

    @staticmethod
    def _verify_subnet_ips(region, ip_addresses, initial=True):
        """Perform additional checks on the IPAddresses.

        Each entry must have "SubnetId" and "Ip" keys.  Every IP must lie in
        its subnet's CIDR block, must not be reserved, and must not repeat
        within the same subnet.  When ``initial`` is True (endpoint creation),
        at least two entries are required.

        NOTE:  This does not include IPv6 addresses.
        """
        # only initial endpoint creation requires atleast two ip addresses
        if initial:
            if len(ip_addresses) < 2:
                raise InvalidRequestException(
                    "Resolver endpoint needs to have at least 2 IP addresses"
                )

        subnets = defaultdict(set)
        for subnet_id, ip_addr in [(x["SubnetId"], x["Ip"]) for x in ip_addresses]:
            try:
                subnet_info = ec2_backends[region].get_all_subnets(
                    subnet_ids=[subnet_id]
                )[0]
            except InvalidSubnetIdError as exc:
                raise InvalidParameterException(
                    f"The subnet ID '{subnet_id}' does not exist"
                ) from exc

            # IP in IPv4 CIDR range and not reserved?
            if ip_address(ip_addr) in subnet_info.reserved_ips or ip_address(
                ip_addr
            ) not in ip_network(subnet_info.cidr_block):
                raise InvalidRequestException(
                    f"IP address '{ip_addr}' is either not in subnet "
                    f"'{subnet_id}' CIDR range or is reserved"
                )

            if ip_addr in subnets[subnet_id]:
                raise ResourceExistsException(
                    f"The IP address '{ip_addr}' in subnet '{subnet_id}' is already in use"
                )
            subnets[subnet_id].add(ip_addr)

    @staticmethod
    def _verify_security_group_ids(region, security_group_ids):
        """Perform additional checks on the security groups."""
        if len(security_group_ids) > 10:
            raise InvalidParameterException("Maximum of 10 security groups are allowed")

        for group_id in security_group_ids:
            if not group_id.startswith("sg-"):
                raise InvalidParameterException(
                    f"Malformed security group ID: Invalid id: '{group_id}' "
                    f"(expecting 'sg-...')"
                )
            try:
                ec2_backends[region].describe_security_groups(group_ids=[group_id])
            except InvalidSecurityGroupNotFoundError as exc:
                raise ResourceNotFoundException(
                    f"The security group '{group_id}' does not exist"
                ) from exc

    def create_resolver_endpoint(
        self,
        region,
        creator_request_id,
        name,
        security_group_ids,
        direction,
        ip_addresses,
        tags,
    ):  # pylint: disable=too-many-arguments
        """
        Return description for a newly created resolver endpoint.

        Entries in ``ip_addresses`` without an "Ip" are assigned one from
        their subnet; this mutates the caller's dicts.

        NOTE:  IPv6 IPs are currently not being filtered when
        calculating the create_resolver_endpoint() IpAddresses.
        """
        validate_args(
            [
                ("creatorRequestId", creator_request_id),
                ("direction", direction),
                ("ipAddresses", ip_addresses),
                ("name", name),
                ("securityGroupIds", security_group_ids),
                ("ipAddresses.subnetId", ip_addresses),
            ]
        )
        errmsg = self.tagger.validate_tags(
            tags or [], limit=ResolverEndpoint.MAX_TAGS_PER_RESOLVER_ENDPOINT
        )
        if errmsg:
            raise TagValidationException(errmsg)

        endpoints = [x for x in self.resolver_endpoints.values() if x.region == region]
        # NOTE(review): ">" admits one more endpoint than the constant's name
        # suggests; confirm the intended limit semantics.
        if len(endpoints) > ResolverEndpoint.MAX_ENDPOINTS_PER_REGION:
            raise LimitExceededException(
                f"Account '{get_account_id()}' has exceeded 'max-endpoints'"
            )

        # Fill in an IP for any entry that didn't specify one.
        for x in ip_addresses:
            if not x.get("Ip"):
                subnet_info = ec2_backends[region].get_all_subnets(
                    subnet_ids=[x["SubnetId"]]
                )[0]
                x["Ip"] = subnet_info.get_available_subnet_ip(self)

        self._verify_subnet_ips(region, ip_addresses)
        self._verify_security_group_ids(region, security_group_ids)
        if creator_request_id in [
            x.creator_request_id for x in self.resolver_endpoints.values()
        ]:
            raise ResourceExistsException(
                f"Resolver endpoint with creator request ID "
                f"'{creator_request_id}' already exists"
            )

        endpoint_id = (
            f"rslvr-{'in' if direction == 'INBOUND' else 'out'}-{get_random_hex(17)}"
        )
        resolver_endpoint = ResolverEndpoint(
            region,
            endpoint_id,
            creator_request_id,
            security_group_ids,
            direction,
            ip_addresses,
            name,
        )

        self.resolver_endpoints[endpoint_id] = resolver_endpoint
        self.tagger.tag_resource(resolver_endpoint.arn, tags or [])
        return resolver_endpoint

    def create_resolver_rule(
        self,
        region,
        creator_request_id,
        name,
        rule_type,
        domain_name,
        target_ips,
        resolver_endpoint_id,
        tags,
    ):  # pylint: disable=too-many-arguments
        """Return description for a newly created resolver rule."""
        validate_args(
            [
                ("creatorRequestId", creator_request_id),
                ("ruleType", rule_type),
                ("domainName", domain_name),
                ("name", name),
                *[("targetIps.port", x) for x in target_ips],
                ("resolverEndpointId", resolver_endpoint_id),
            ]
        )
        errmsg = self.tagger.validate_tags(
            tags or [], limit=ResolverRule.MAX_TAGS_PER_RESOLVER_RULE
        )
        if errmsg:
            raise TagValidationException(errmsg)

        rules = [x for x in self.resolver_rules.values() if x.region == region]
        if len(rules) > ResolverRule.MAX_RULES_PER_REGION:
            # Did not verify that this is the actual error message.
            raise LimitExceededException(
                f"Account '{get_account_id()}' has exceeded 'max-rules'"
            )

        # Per the AWS documentation and as seen with the AWS console, target
        # ips are only relevant when the value of Rule is FORWARD.  However,
        # boto3 ignores this condition and so shall we.
        for ip_addr in [x["Ip"] for x in target_ips]:
            try:
                # boto3 fails with an InternalServiceException if IPv6
                # addresses are used, which isn't helpful.
                if not isinstance(ip_address(ip_addr), IPv4Address):
                    raise InvalidParameterException(
                        f"Only IPv4 addresses may be used: '{ip_addr}'"
                    )
            except ValueError as exc:
                raise InvalidParameterException(
                    f"Invalid IP address: '{ip_addr}'"
                ) from exc

        # The boto3 documentation indicates that ResolverEndpoint is
        # optional, as does the AWS documention.  But if resolver_endpoint_id
        # is set to None or an empty string, it results in boto3 raising
        # a ParamValidationError either regarding the type or len of string.
        if resolver_endpoint_id:
            if resolver_endpoint_id not in [
                x.id for x in self.resolver_endpoints.values()
            ]:
                raise ResourceNotFoundException(
                    f"Resolver endpoint with ID '{resolver_endpoint_id}' does not exist."
                )

            if rule_type == "SYSTEM":
                raise InvalidRequestException(
                    "Cannot specify resolver endpoint ID and target IP "
                    "for SYSTEM type resolver rule"
                )

        if creator_request_id in [
            x.creator_request_id for x in self.resolver_rules.values()
        ]:
            raise ResourceExistsException(
                f"Resolver rule with creator request ID "
                f"'{creator_request_id}' already exists"
            )

        rule_id = f"rslvr-rr-{get_random_hex(17)}"
        resolver_rule = ResolverRule(
            region,
            rule_id,
            creator_request_id,
            rule_type,
            domain_name,
            target_ips,
            resolver_endpoint_id,
            name,
        )

        self.resolver_rules[rule_id] = resolver_rule
        self.tagger.tag_resource(resolver_rule.arn, tags or [])
        return resolver_rule

    def _validate_resolver_endpoint_id(self, resolver_endpoint_id):
        """Raise an exception if the id is invalid or unknown."""
        validate_args([("resolverEndpointId", resolver_endpoint_id)])
        if resolver_endpoint_id not in self.resolver_endpoints:
            raise ResourceNotFoundException(
                f"Resolver endpoint with ID '{resolver_endpoint_id}' does not exist"
            )

    def delete_resolver_endpoint(self, resolver_endpoint_id):
        """Delete an endpoint with no dependent rules; return it as DELETING."""
        self._validate_resolver_endpoint_id(resolver_endpoint_id)

        # Can't delete an endpoint if there are rules associated with it.
        rules = [
            x.id
            for x in self.resolver_rules.values()
            if x.resolver_endpoint_id == resolver_endpoint_id
        ]
        if rules:
            raise InvalidRequestException(
                f"Cannot delete resolver endpoint unless its related resolver "
                f"rules are deleted.  The following rules still exist for "
                f"this resolver endpoint:  {','.join(rules)}"
            )

        resolver_endpoint = self.resolver_endpoints.pop(resolver_endpoint_id)
        # Tags were stored under the ARN (see create_resolver_endpoint), so
        # deleting by ID would leave them behind.
        self.tagger.delete_all_tags_for_resource(resolver_endpoint.arn)
        resolver_endpoint.delete_eni()
        resolver_endpoint.status = "DELETING"
        resolver_endpoint.status_message = resolver_endpoint.status_message.replace(
            "Successfully created", "Deleting"
        )
        return resolver_endpoint

    def _validate_resolver_rule_id(self, resolver_rule_id):
        """Raise an exception if the id is invalid or unknown."""
        validate_args([("resolverRuleId", resolver_rule_id)])
        if resolver_rule_id not in self.resolver_rules:
            raise ResourceNotFoundException(
                f"Resolver rule with ID '{resolver_rule_id}' does not exist"
            )

    def delete_resolver_rule(self, resolver_rule_id):
        """Delete a rule with no VPC associations; return it as DELETING."""
        self._validate_resolver_rule_id(resolver_rule_id)

        # Can't delete an rule unless VPC's are disassociated.
        associations = [
            x.id
            for x in self.resolver_rule_associations.values()
            if x.resolver_rule_id == resolver_rule_id
        ]
        if associations:
            raise ResourceInUseException(
                "Please disassociate this resolver rule from VPC first "
                "before deleting"
            )

        resolver_rule = self.resolver_rules.pop(resolver_rule_id)
        # Tags were stored under the ARN (see create_resolver_rule), so
        # deleting by ID would leave them behind.
        self.tagger.delete_all_tags_for_resource(resolver_rule.arn)
        resolver_rule.status = "DELETING"
        resolver_rule.status_message = resolver_rule.status_message.replace(
            "Successfully created", "Deleting"
        )
        return resolver_rule

    def disassociate_resolver_rule(self, resolver_rule_id, vpc_id):
        """Remove the rule/VPC association and return it as DELETING."""
        validate_args([("resolverRuleId", resolver_rule_id), ("vPCId", vpc_id)])

        # Non-existent rule or vpc ids?
        if resolver_rule_id not in self.resolver_rules:
            raise ResourceNotFoundException(
                f"Resolver rule with ID '{resolver_rule_id}' does not exist"
            )

        # Find the matching association for this rule and vpc.
        for association in self.resolver_rule_associations.values():
            if (
                resolver_rule_id == association.resolver_rule_id
                and vpc_id == association.vpc_id
            ):
                rule_association_id = association.id
                break
        else:
            raise ResourceNotFoundException(
                f"Resolver Rule Association between Resolver Rule "
                f"'{resolver_rule_id}' and VPC '{vpc_id}' does not exist"
            )

        rule_association = self.resolver_rule_associations.pop(rule_association_id)
        rule_association.status = "DELETING"
        rule_association.status_message = "Deleting Association"
        return rule_association

    def get_resolver_endpoint(self, resolver_endpoint_id):
        """Return info for specified resolver endpoint."""
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        return self.resolver_endpoints[resolver_endpoint_id]

    def get_resolver_rule(self, resolver_rule_id):
        """Return info for specified resolver rule."""
        self._validate_resolver_rule_id(resolver_rule_id)
        return self.resolver_rules[resolver_rule_id]

    def get_resolver_rule_association(self, resolver_rule_association_id):
        """Return info for specified resolver rule association."""
        validate_args([("resolverRuleAssociationId", resolver_rule_association_id)])
        if resolver_rule_association_id not in self.resolver_rule_associations:
            raise ResourceNotFoundException(
                f"ResolverRuleAssociation '{resolver_rule_association_id}' does not Exist"
            )
        return self.resolver_rule_associations[resolver_rule_association_id]

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_endpoint_ip_addresses(self, resolver_endpoint_id):
        """Return the IP descriptions for the specified resolver endpoint."""
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        endpoint = self.resolver_endpoints[resolver_endpoint_id]
        return endpoint.ip_descriptions()

    @staticmethod
    def _add_field_name_to_filter(filters):
        """Convert both styles of filter names to lowercase snake format.

        "IP_ADDRESS_COUNT" or "IpAddressCount" will become "ip_address_count".
        However, "HostVPCId" doesn't fit the pattern, so that's treated
        special.  The result is stored in each filter's "Field" key (the
        filters are mutated in place).
        """
        for rr_filter in filters:
            filter_name = rr_filter["Name"]
            if "_" not in filter_name:
                if "Vpc" in filter_name:
                    # Deliberately unmatchable so _validate_filters rejects it.
                    filter_name = "WRONG"
                elif filter_name == "HostVPCId":
                    filter_name = "host_vpc_id"
                elif filter_name == "VPCId":
                    filter_name = "vpc_id"
                elif filter_name in ["Type", "TYPE"]:
                    filter_name = "rule_type"
                elif not filter_name.isupper():
                    filter_name = CAMEL_TO_SNAKE_PATTERN.sub("_", filter_name)
            rr_filter["Field"] = filter_name.lower()

    @staticmethod
    def _validate_filters(filters, allowed_filter_names):
        """Raise exception if filter names are not as expected."""
        for rr_filter in filters:
            if rr_filter["Field"] not in allowed_filter_names:
                raise InvalidParameterException(
                    f"The filter '{rr_filter['Name']}' is invalid"
                )
            if "Values" not in rr_filter:
                raise InvalidParameterException(
                    f"No values specified for filter {rr_filter['Name']}"
                )

    @staticmethod
    def _matches_all_filters(entity, filters):
        """Return True if this entity has fields matching all the filters."""
        for rr_filter in filters:
            field_value = getattr(entity, rr_filter["Field"])

            if isinstance(field_value, list):
                if not set(field_value).intersection(rr_filter["Values"]):
                    return False
            elif isinstance(field_value, int):
                # Filter values arrive as strings.
                if str(field_value) not in rr_filter["Values"]:
                    return False
            elif field_value not in rr_filter["Values"]:
                return False

        return True

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_endpoints(self, filters):
        """Return resolver endpoints matching all filters, sorted by name."""
        if not filters:
            filters = []

        self._add_field_name_to_filter(filters)
        self._validate_filters(filters, ResolverEndpoint.FILTER_NAMES)

        return [
            endpoint
            for endpoint in sorted(
                self.resolver_endpoints.values(), key=self._name_sort_key
            )
            if self._matches_all_filters(endpoint, filters)
        ]

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_rules(self, filters):
        """Return resolver rules matching all filters, sorted by name."""
        if not filters:
            filters = []

        self._add_field_name_to_filter(filters)
        self._validate_filters(filters, ResolverRule.FILTER_NAMES)

        return [
            rule
            for rule in sorted(self.resolver_rules.values(), key=self._name_sort_key)
            if self._matches_all_filters(rule, filters)
        ]

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_rule_associations(self, filters):
        """Return rule associations matching all filters, sorted by name."""
        if not filters:
            filters = []

        self._add_field_name_to_filter(filters)
        self._validate_filters(filters, ResolverRuleAssociation.FILTER_NAMES)

        return [
            rule
            for rule in sorted(
                self.resolver_rule_associations.values(), key=self._name_sort_key
            )
            if self._matches_all_filters(rule, filters)
        ]

    def _matched_arn(self, resource_arn):
        """Return the endpoint or rule with this ARN, else raise an exception."""
        for resolver_endpoint in self.resolver_endpoints.values():
            if resolver_endpoint.arn == resource_arn:
                return resolver_endpoint
        for resolver_rule in self.resolver_rules.values():
            if resolver_rule.arn == resource_arn:
                return resolver_rule
        raise ResourceNotFoundException(
            f"Resolver endpoint with ID '{resource_arn}' does not exist"
        )

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags_for_resource(self, resource_arn):
        """Return the tags attached to the endpoint or rule with this ARN."""
        self._matched_arn(resource_arn)
        return self.tagger.list_tags_for_resource(resource_arn).get("Tags")

    def tag_resource(self, resource_arn, tags):
        """Add tags to the endpoint or rule identified by resource_arn."""
        resource = self._matched_arn(resource_arn)
        # Apply the same per-type tag limit used when the resource was created.
        if isinstance(resource, ResolverRule):
            limit = ResolverRule.MAX_TAGS_PER_RESOLVER_RULE
        else:
            limit = ResolverEndpoint.MAX_TAGS_PER_RESOLVER_ENDPOINT
        errmsg = self.tagger.validate_tags(tags, limit=limit)
        if errmsg:
            raise TagValidationException(errmsg)
        self.tagger.tag_resource(resource_arn, tags)

    def untag_resource(self, resource_arn, tag_keys):
        """Remove the named tags from the endpoint or rule with this ARN."""
        self._matched_arn(resource_arn)
        self.tagger.untag_resource_using_names(resource_arn, tag_keys)

    def update_resolver_endpoint(self, resolver_endpoint_id, name):
        """Rename the specified resolver endpoint and return it."""
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        validate_args([("name", name)])
        resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]
        resolver_endpoint.update_name(name)
        return resolver_endpoint

    def associate_resolver_endpoint_ip_address(
        self, region, resolver_endpoint_id, ip_address
    ):
        """Add an IP address (dict with SubnetId and optional Ip) to an endpoint.

        If "Ip" is missing, one is assigned from the subnet, mutating
        ``ip_address``.  Note the parameter shadows the module-level
        ``ip_address`` function inside this method only.
        """
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]

        if not ip_address.get("Ip"):
            subnet_info = ec2_backends[region].get_all_subnets(
                subnet_ids=[ip_address.get("SubnetId")]
            )[0]
            ip_address["Ip"] = subnet_info.get_available_subnet_ip(self)
        self._verify_subnet_ips(region, [ip_address], False)

        resolver_endpoint.associate_ip_address(ip_address)
        return resolver_endpoint

    def disassociate_resolver_endpoint_ip_address(
        self, resolver_endpoint_id, ip_address
    ):
        """Remove an IP address, given by Ip or IpId, from an endpoint."""
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]

        if not (ip_address.get("Ip") or ip_address.get("IpId")):
            raise InvalidRequestException(
                "[RSLVR-00503] Need to specify either the IP ID or both subnet and IP address in order to remove IP address."
            )

        resolver_endpoint.disassociate_ip_address(ip_address)
        return resolver_endpoint
class DirectoryServiceBackend(BaseBackend):
    """Implementation of DirectoryService APIs."""

    def __init__(self, region_name=None):
        self.region_name = region_name
        self.directories = {}  # Key is the directory ID.
        self.tagger = TaggingService()

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """List of dicts representing default VPC endpoints for this service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "ds"
        )

    @staticmethod
    def _validate_create_directory_args(
        name, passwd, size, vpc_settings, description, short_name,
    ):  # pylint: disable=too-many-arguments
        """Raise exception if create_directory() args don't meet constraints.

        The error messages are accumulated before the exception is raised.
        Each accumulated entry is a (field name, value, constraint) tuple
        passed to DsValidationException.
        """
        error_tuples = []
        passwd_pattern = (
            r"(?=^.{8,64}$)((?=.*\d)(?=.*[A-Z])(?=.*[a-z])|"
            r"(?=.*\d)(?=.*[^A-Za-z0-9\s])(?=.*[a-z])|"
            r"(?=.*[^A-Za-z0-9\s])(?=.*[A-Z])(?=.*[a-z])|"
            r"(?=.*\d)(?=.*[A-Z])(?=.*[^A-Za-z0-9\s]))^.*"
        )
        if not re.match(passwd_pattern, passwd):
            # Can't have an odd number of backslashes in a literal.
            json_pattern = passwd_pattern.replace("\\", r"\\")
            error_tuples.append(
                (
                    "password",
                    passwd,
                    fr"satisfy regular expression pattern: {json_pattern}",
                )
            )

        if size.lower() not in ["small", "large"]:
            error_tuples.append(
                ("size", size, "satisfy enum value set: [Small, Large]")
            )

        name_pattern = r"^([a-zA-Z0-9]+[\\.-])+([a-zA-Z0-9])+$"
        if not re.match(name_pattern, name):
            error_tuples.append(
                ("name", name, fr"satisfy regular expression pattern: {name_pattern}")
            )

        subnet_id_pattern = r"^(subnet-[0-9a-f]{8}|subnet-[0-9a-f]{17})$"
        for subnet in vpc_settings["SubnetIds"]:
            if not re.match(subnet_id_pattern, subnet):
                error_tuples.append(
                    (
                        "vpcSettings.subnetIds",
                        subnet,
                        fr"satisfy regular expression pattern: {subnet_id_pattern}",
                    )
                )

        if description and len(description) > 128:
            error_tuples.append(
                ("description", description, "have length less than or equal to 128")
            )

        short_name_pattern = r'^[^\/:*?"<>|.]+[^\/:*?"<>|]*$'
        if short_name and not re.match(short_name_pattern, short_name):
            json_pattern = short_name_pattern.replace("\\", r"\\").replace('"', r"\"")
            error_tuples.append(
                (
                    "shortName",
                    short_name,
                    fr"satisfy regular expression pattern: {json_pattern}",
                )
            )

        if error_tuples:
            raise DsValidationException(error_tuples)

    @staticmethod
    def _validate_vpc_setting_values(region, vpc_settings):
        """Raise exception if vpc_settings are invalid.

        If settings are valid, add AvailabilityZones to vpc_settings.
        """
        if len(vpc_settings["SubnetIds"]) != 2:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones."
            )

        from moto.ec2 import ec2_backends  # pylint: disable=import-outside-toplevel

        # Subnet IDs are checked before the VPC ID.  The Subnet IDs must
        # be valid and in different availability zones.
        try:
            subnets = ec2_backends[region].get_all_subnets(
                subnet_ids=vpc_settings["SubnetIds"]
            )
        except InvalidSubnetIdError as exc:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones."
            ) from exc

        zones = [subnet.availability_zone for subnet in subnets]
        # Requiring two distinct zones is equivalent to zones[0] != zones[1]
        # when two subnets come back, and avoids an IndexError if the lookup
        # returns fewer (presumably when the same ID is listed twice).
        if len(set(zones)) != 2:
            raise ClientException(
                "Invalid subnet ID(s). The two subnets must be in "
                "different Availability Zones."
            )

        vpcs = ec2_backends[region].describe_vpcs()
        if vpc_settings["VpcId"] not in [x.id for x in vpcs]:
            raise ClientException("Invalid VPC ID.")

        vpc_settings["AvailabilityZones"] = zones

    def _count_directories(self, *directory_types):
        """Return how many stored directories have a type in directory_types."""
        return sum(
            1
            for directory in self.directories.values()
            if directory.directory_type in directory_types
        )

    def create_directory(
        self, region, name, short_name, password, description, size, vpc_settings, tags
    ):  # pylint: disable=too-many-arguments
        """Create a fake Simple AD directory and return its directory ID.

        ``tags`` may be None, which is treated as no tags.

        Raises:
            DirectoryLimitExceededException: the Simple AD limit has been
                reached, or too many tags were supplied.
            InvalidParameterException: vpc_settings is missing.
            ValidationException: the tags are malformed.
        """
        # Compare against Simple AD directories only, matching
        # get_directory_limits().
        if self._count_directories("SimpleAD") >= Directory.CLOUDONLY_DIRECTORIES_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CLOUDONLY_DIRECTORIES_LIMIT} directories may be created"
            )

        # botocore doesn't look for missing vpc_settings, but boto3 does.
        if not vpc_settings:
            raise InvalidParameterException("VpcSettings must be specified.")

        self._validate_create_directory_args(
            name, password, size, vpc_settings, description, short_name,
        )
        self._validate_vpc_setting_values(region, vpc_settings)

        # Normalize once so len() below cannot fail on None.
        tags = tags or []
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            name,
            password,
            size,
            vpc_settings,
            directory_type="SimpleAD",
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags)
        return directory.directory_id

    def _validate_directory_id(self, directory_id):
        """Raise an exception if the directory id is invalid or unknown."""
        # Validation of ID takes precedence over a check for its existence.
        id_pattern = r"^d-[0-9a-f]{10}$"
        if not re.match(id_pattern, directory_id):
            raise DsValidationException(
                [
                    (
                        "directoryId",
                        directory_id,
                        fr"satisfy regular expression pattern: {id_pattern}",
                    )
                ]
            )

        if directory_id not in self.directories:
            raise EntityDoesNotExistException(
                f"Directory {directory_id} does not exist"
            )

    def delete_directory(self, directory_id):
        """Delete directory with the matching ID and its tags; return the ID."""
        self._validate_directory_id(directory_id)
        self.tagger.delete_all_tags_for_resource(directory_id)
        self.directories.pop(directory_id)
        return directory_id

    @paginate(pagination_model=PAGINATION_MODEL)
    def describe_directories(
        self, directory_ids=None, next_token=None, limit=0
    ):  # pylint: disable=unused-argument
        """Return info on all directories or directories with matching IDs."""
        for directory_id in directory_ids or self.directories:
            self._validate_directory_id(directory_id)

        directories = list(self.directories.values())
        if directory_ids:
            directories = [x for x in directories if x.directory_id in directory_ids]
        return sorted(directories, key=lambda x: x.launch_time)

    def get_directory_limits(self):
        """Return hard-coded limits alongside per-type directory counts."""
        simple_ad_count = self._count_directories("SimpleAD")
        microsoft_ad_count = self._count_directories(
            "MicrosoftAD", "SharedMicrosoftAD"
        )
        connected_count = self._count_directories("ADConnector")
        return {
            "CloudOnlyDirectoriesLimit": Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyDirectoriesCurrentCount": simple_ad_count,
            "CloudOnlyDirectoriesLimitReached": simple_ad_count
            >= Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyMicrosoftADLimit": Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "CloudOnlyMicrosoftADCurrentCount": microsoft_ad_count,
            "CloudOnlyMicrosoftADLimitReached": microsoft_ad_count
            >= Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "ConnectedDirectoriesLimit": Directory.CONNECTED_DIRECTORIES_LIMIT,
            "ConnectedDirectoriesCurrentCount": connected_count,
            "ConnectedDirectoriesLimitReached": connected_count
            >= Directory.CONNECTED_DIRECTORIES_LIMIT,
        }

    def add_tags_to_resource(self, resource_id, tags):
        """Add or overwrite one or more tags for specified directory."""
        self._validate_directory_id(resource_id)
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise TagLimitExceededException("Tag limit exceeded")
        self.tagger.tag_resource(resource_id, tags)

    def remove_tags_from_resource(self, resource_id, tag_keys):
        """Removes tags from a directory."""
        self._validate_directory_id(resource_id)
        self.tagger.untag_resource_using_names(resource_id, tag_keys)

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags_for_resource(
        self,
        resource_id,
        next_token=None,
        limit=None,
    ):  # pylint: disable=unused-argument
        """List all tags on a directory."""
        self._validate_directory_id(resource_id)
        return self.tagger.list_tags_for_resource(resource_id).get("Tags")
class FirehoseBackend(BaseBackend):
    """Implementation of Firehose APIs."""

    def __init__(self, region_name=None):
        self.region_name = region_name
        # Delivery stream objects, indexed by delivery stream name.
        self.delivery_streams = {}
        # Tags are stored against each delivery stream's ARN.
        self.tagger = TaggingService()

    def reset(self):
        """Re-initializes all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "firehose", special_service_name="kinesis-firehose"
        )

    def _get_delivery_stream(self, delivery_stream_name):
        """Return the named delivery stream or raise if it is unknown.

        Raises:
            ResourceNotFoundException: no stream with that name is stored.
        """
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found."
            )
        return delivery_stream

    def create_delivery_stream(
        self,
        region,
        delivery_stream_name,
        delivery_stream_type,
        kinesis_stream_source_configuration,
        delivery_stream_encryption_configuration_input,
        s3_destination_configuration,
        extended_s3_destination_configuration,
        redshift_destination_configuration,
        elasticsearch_destination_configuration,
        splunk_destination_configuration,
        http_endpoint_destination_configuration,
        tags,
    ):  # pylint: disable=too-many-arguments,too-many-locals,unused-argument
        """Create a Kinesis Data Firehose delivery stream.

        Returns:
            The ARN of the newly created delivery stream.

        Raises:
            ResourceInUseException: a stream with this name already exists.
            LimitExceededException: the per-region stream quota is used up.
            InvalidArgumentException: a Kinesis source configuration was
                supplied for a stream type other than "KinesisStreamAsSource".
            ValidationException: the tags are invalid or too numerous.
        """
        # locals() is handed to find_destination_config_in_args, so this must
        # run before any other local variable is bound in this method.
        (destination_name, destination_config) = find_destination_config_in_args(locals())

        if delivery_stream_name in self.delivery_streams:
            raise ResourceInUseException(
                f"Firehose {delivery_stream_name} under accountId {get_account_id()} "
                f"already exists"
            )

        # ">=" so the quota holds even if the store somehow exceeds it.
        if len(self.delivery_streams) >= DeliveryStream.MAX_STREAMS_PER_REGION:
            raise LimitExceededException(
                f"You have already consumed your firehose quota of "
                f"{DeliveryStream.MAX_STREAMS_PER_REGION} hoses. Firehose "
                f"names: {list(self.delivery_streams.keys())}"
            )

        # Rule out situations that are not yet implemented.
        if delivery_stream_encryption_configuration_input:
            warnings.warn(
                "A delivery stream with server-side encryption enabled is not "
                "yet implemented"
            )

        if destination_name == "Splunk":
            warnings.warn("A Splunk destination delivery stream is not yet implemented")

        if (
            kinesis_stream_source_configuration
            and delivery_stream_type != "KinesisStreamAsSource"
        ):
            raise InvalidArgumentException(
                "KinesisSourceStreamConfig is only applicable for "
                "KinesisStreamAsSource stream type"
            )

        # Validate the tags before proceeding.
        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)

        if tags and len(tags) > MAX_TAGS_PER_DELIVERY_STREAM:
            # NOTE(review): "satisify contstraint" is misspelled; kept verbatim
            # in case callers or tests match on this exact text.
            raise ValidationException(
                f"1 validation error detected: Value '{tags}' at 'tags' "
                f"failed to satisify contstraint: Member must have length "
                f"less than or equal to {MAX_TAGS_PER_DELIVERY_STREAM}"
            )

        # Create a DeliveryStream instance that will be stored and indexed
        # by delivery stream name. This instance will update the state and
        # create the ARN.
        delivery_stream = DeliveryStream(
            region,
            delivery_stream_name,
            delivery_stream_type,
            kinesis_stream_source_configuration,
            destination_name,
            destination_config,
        )
        self.tagger.tag_resource(delivery_stream.delivery_stream_arn, tags or [])
        self.delivery_streams[delivery_stream_name] = delivery_stream
        return delivery_stream.delivery_stream_arn

    def delete_delivery_stream(
        self, delivery_stream_name, allow_force_delete=False
    ):  # pylint: disable=unused-argument
        """Delete a delivery stream and its data.

        AllowForceDelete option is ignored as we only superficially apply state.

        Raises:
            ResourceNotFoundException: the stream does not exist.
        """
        delivery_stream = self._get_delivery_stream(delivery_stream_name)
        self.tagger.delete_all_tags_for_resource(delivery_stream.delivery_stream_arn)
        delivery_stream.delivery_stream_status = "DELETING"
        self.delivery_streams.pop(delivery_stream_name)

    def describe_delivery_stream(
        self, delivery_stream_name, limit, exclusive_start_destination_id
    ):  # pylint: disable=unused-argument
        """Return description of specified delivery stream and its status.

        Note: the 'limit' and 'exclusive_start_destination_id' parameters are
        not currently processed/implemented.

        Every truthy attribute of the stream object is copied into the result
        under a CamelCase key derived from the snake_case attribute name.

        Raises:
            ResourceNotFoundException: the stream does not exist.
        """
        delivery_stream = self._get_delivery_stream(delivery_stream_name)

        result = {"DeliveryStreamDescription": {"HasMoreDestinations": False}}
        stream_description = result["DeliveryStreamDescription"]
        for attribute, attribute_value in vars(delivery_stream).items():
            if not attribute_value:
                continue

            # Convert from attribute's snake case to camel case for outgoing
            # JSON.
            name = "".join([x.capitalize() for x in attribute.split("_")])

            # Fooey ... always an exception to the rule:
            if name == "DeliveryStreamArn":
                name = "DeliveryStreamARN"

            if name == "Source":
                stream_description[name] = {
                    "KinesisStreamSourceDescription": attribute_value
                }
                continue
            if name != "Destinations":
                stream_description[name] = attribute_value
                continue

            # Each stored destination maps a destination type (e.g. "S3") to
            # its config; the API reports it as "<type>DestinationDescription".
            stream_description["Destinations"] = []
            for destination in attribute_value:
                description = {}
                for key, value in destination.items():
                    if key == "destination_id":
                        description["DestinationId"] = value
                    else:
                        description[f"{key}DestinationDescription"] = value
                stream_description["Destinations"].append(description)

        return result

    def list_delivery_streams(
        self, limit, delivery_stream_type, exclusive_start_delivery_stream_name
    ):
        """Return list of delivery streams in alphabetic order of names.

        Args:
            limit: maximum number of names to return; falsy means
                DeliveryStream.MAX_STREAMS_PER_REGION.
            delivery_stream_type: if given, only streams of this type are listed.
            exclusive_start_delivery_stream_name: if it names a stream in the
                (filtered) list, listing starts just after it; otherwise it
                is ignored.
        """
        result = {"DeliveryStreamNames": [], "HasMoreDeliveryStreams": False}
        if not self.delivery_streams:
            return result

        # If delivery_stream_type is specified, filter out any stream that's
        # not of that type.
        stream_list = self.delivery_streams.keys()
        if delivery_stream_type:
            stream_list = [
                x
                for x in stream_list
                if self.delivery_streams[x].delivery_stream_type == delivery_stream_type
            ]

        # The list is sorted alphabetically, not alphanumerically.
        sorted_list = sorted(stream_list)

        # Determine the limit or number of names to return in the list.
        limit = limit or DeliveryStream.MAX_STREAMS_PER_REGION

        # If a starting delivery stream name is given, find the index into
        # the sorted list, then add one to get the name following it. The
        # check is against the *filtered* list: a stream that exists but was
        # filtered out by type would otherwise make list.index() raise
        # ValueError. A name that isn't in the list is ignored.
        start = 0
        if (
            exclusive_start_delivery_stream_name
            and exclusive_start_delivery_stream_name in sorted_list
        ):
            start = sorted_list.index(exclusive_start_delivery_stream_name) + 1

        result["DeliveryStreamNames"] = sorted_list[start:start + limit]
        if len(sorted_list) > (start + limit):
            result["HasMoreDeliveryStreams"] = True
        return result

    def list_tags_for_delivery_stream(
        self, delivery_stream_name, exclusive_start_tag_key, limit
    ):
        """Return list of tags.

        Raises:
            ResourceNotFoundException: the stream does not exist.
        """
        result = {"Tags": [], "HasMoreTags": False}
        delivery_stream = self._get_delivery_stream(delivery_stream_name)

        tags = self.tagger.list_tags_for_resource(delivery_stream.delivery_stream_arn)[
            "Tags"
        ]
        keys = self.tagger.extract_tag_names(tags)

        # If a starting tag is given and can be found, find the index into
        # tags, then add one to get the tag following it.
        start = 0
        if exclusive_start_tag_key and exclusive_start_tag_key in keys:
            start = keys.index(exclusive_start_tag_key) + 1

        limit = limit or MAX_TAGS_PER_DELIVERY_STREAM
        result["Tags"] = tags[start:start + limit]
        if len(tags) > (start + limit):
            result["HasMoreTags"] = True
        return result

    def put_record(self, delivery_stream_name, record):
        """Write a single data record into a Kinesis Data firehose stream."""
        result = self.put_record_batch(delivery_stream_name, [record])
        return {
            "RecordId": result["RequestResponses"][0]["RecordId"],
            "Encrypted": False,
        }

    @staticmethod
    def put_http_records(http_destination, records):
        """Put records to a HTTP destination.

        Raises:
            RuntimeError: the HTTP POST raised any exception.
        """
        # Mostly copied from localstack
        url = http_destination["EndpointConfiguration"]["Url"]
        headers = {"Content-Type": "application/json"}
        record_to_send = {
            "requestId": str(uuid4()),
            "timestamp": int(time()),
            "records": [{"data": record["Data"]} for record in records],
        }
        try:
            # NOTE(review): no timeout is passed, so an unresponsive endpoint
            # can block this call indefinitely.
            requests.post(url, json=record_to_send, headers=headers)
        except Exception as exc:
            # This could be better ...
            raise RuntimeError(
                "Firehose PutRecord(Batch) to HTTP destination failed"
            ) from exc
        return [{"RecordId": str(uuid4())} for _ in range(len(records))]

    @staticmethod
    def _format_s3_object_path(delivery_stream_name, version_id, prefix):
        """Return a S3 object path in the expected format."""
        # Taken from LocalStack's firehose logic, with minor changes.
        # See https://docs.aws.amazon.com/firehose/latest/dev/basic-deliver.html#s3-object-name
        # Path prefix pattern: myApp/YYYY/MM/DD/HH/
        # Object name pattern:
        # DeliveryStreamName-DeliveryStreamVersion-YYYY-MM-DD-HH-MM-SS-RandomString
        prefix = f"{prefix}{'' if prefix.endswith('/') else '/'}"
        # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
        now = datetime.now(timezone.utc)
        return (
            f"{prefix}{now.strftime('%Y/%m/%d/%H')}/"
            f"{delivery_stream_name}-{version_id}-"
            f"{now.strftime('%Y-%m-%d-%H-%M-%S')}-{str(uuid4())}"
        )

    def put_s3_records(self, delivery_stream_name, version_id, s3_destination, records):
        """Put records to a ExtendedS3 or S3 destination.

        All records are base64-decoded, concatenated and written as a single
        object in the destination bucket.

        Raises:
            RuntimeError: the S3 put_object call raised any exception.
        """
        # Taken from LocalStack's firehose logic, with minor changes.
        bucket_name = s3_destination["BucketARN"].split(":")[-1]
        prefix = s3_destination.get("Prefix", "")
        object_path = self._format_s3_object_path(
            delivery_stream_name, version_id, prefix
        )

        batched_data = b"".join([b64decode(r["Data"]) for r in records])
        try:
            s3_backend.put_object(bucket_name, object_path, batched_data)
        except Exception as exc:
            # This could be better ...
            raise RuntimeError(
                "Firehose PutRecord(Batch) to S3 destination failed"
            ) from exc
        return [{"RecordId": str(uuid4())} for _ in range(len(records))]

    def put_record_batch(self, delivery_stream_name, records):
        """Write multiple data records into a Kinesis Data firehose stream.

        Raises:
            ResourceNotFoundException: the stream does not exist.
        """
        delivery_stream = self._get_delivery_stream(delivery_stream_name)

        request_responses = []
        for destination in delivery_stream.destinations:
            if "ExtendedS3" in destination:
                # ExtendedS3 will be handled like S3,but in the future
                # this will probably need to be revisited. This destination
                # must be listed before S3 otherwise both destinations will
                # be processed instead of just ExtendedS3.
                request_responses = self.put_s3_records(
                    delivery_stream_name,
                    delivery_stream.version_id,
                    destination["ExtendedS3"],
                    records,
                )
            elif "S3" in destination:
                request_responses = self.put_s3_records(
                    delivery_stream_name,
                    delivery_stream.version_id,
                    destination["S3"],
                    records,
                )
            elif "HttpEndpoint" in destination:
                request_responses = self.put_http_records(
                    destination["HttpEndpoint"], records
                )
            else:
                # Delivery to the remaining destination types (Elasticsearch,
                # Redshift, Splunk, ...) isn't implemented, so ignore the data
                # but still return a "proper" response, one RecordId per record.
                # Without this, put_record() would index an empty list.
                request_responses = [
                    {"RecordId": str(uuid4())} for _ in range(len(records))
                ]

        return {
            "FailedPutCount": 0,
            "Encrypted": False,
            "RequestResponses": request_responses,
        }

    def tag_delivery_stream(self, delivery_stream_name, tags):
        """Add/update tags for specified delivery stream.

        Raises:
            ResourceNotFoundException: the stream does not exist.
            ValidationException: too many tags, or a tag is invalid.
        """
        delivery_stream = self._get_delivery_stream(delivery_stream_name)

        if len(tags) > MAX_TAGS_PER_DELIVERY_STREAM:
            raise ValidationException(
                f"1 validation error detected: Value '{tags}' at 'tags' "
                f"failed to satisify contstraint: Member must have length "
                f"less than or equal to {MAX_TAGS_PER_DELIVERY_STREAM}"
            )

        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)

        self.tagger.tag_resource(delivery_stream.delivery_stream_arn, tags)

    def untag_delivery_stream(self, delivery_stream_name, tag_keys):
        """Removes tags from specified delivery stream.

        Raises:
            ResourceNotFoundException: the stream does not exist.
        """
        delivery_stream = self._get_delivery_stream(delivery_stream_name)

        # If a tag key doesn't exist for the stream, boto3 ignores it.
        self.tagger.untag_resource_using_names(
            delivery_stream.delivery_stream_arn, tag_keys
        )

    def update_destination(
        self,
        delivery_stream_name,
        current_delivery_stream_version_id,
        destination_id,
        s3_destination_update,
        extended_s3_destination_update,
        s3_backup_mode,
        redshift_destination_update,
        elasticsearch_destination_update,
        splunk_destination_update,
        http_endpoint_destination_update,
    ):  # pylint: disable=unused-argument,too-many-arguments,too-many-locals
        """Updates specified destination of specified delivery stream.

        On success the stream's version_id is incremented and its
        last_update_timestamp refreshed.

        Raises:
            ResourceNotFoundException: the stream does not exist.
            ConcurrentModificationException: the supplied version id does not
                match the stream's current version id.
            InvalidArgumentException: the destination id is unknown, or the
                update switches to/from an Elasticsearch or HTTP destination.
        """
        # locals() is handed to find_destination_config_in_args, so this must
        # run before any other local variable is bound in this method.
        (destination_name, destination_config) = find_destination_config_in_args(locals())

        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under accountId "
                f"{get_account_id()} not found."
            )

        if destination_name == "Splunk":
            warnings.warn("A Splunk destination delivery stream is not yet implemented")

        if delivery_stream.version_id != current_delivery_stream_version_id:
            raise ConcurrentModificationException(
                f"Cannot update firehose: {delivery_stream_name} since the "
                f"current version id: {delivery_stream.version_id} and "
                f"specified version id: {current_delivery_stream_version_id} "
                f"do not match"
            )

        for destination_idx, destination in enumerate(delivery_stream.destinations):
            if destination["destination_id"] == destination_id:
                break
        else:
            raise InvalidArgumentException(f"Destination Id {destination_id} not found")

        # Switching between Amazon ES and other services is not supported.
        # For an Amazon ES destination, you can only update to another Amazon
        # ES destination. Same with HTTP. Didn't test Splunk.
        if (destination_name == "Elasticsearch" and "Elasticsearch" not in destination) or (
            destination_name == "HttpEndpoint" and "HttpEndpoint" not in destination
        ):
            raise InvalidArgumentException(
                f"Changing the destination type to or from {destination_name} "
                f"is not supported at this time."
            )

        # If this is a different type of destination configuration,
        # the existing configuration is reset first.
        if destination_name in destination:
            delivery_stream.destinations[destination_idx][destination_name].update(
                destination_config
            )
        else:
            delivery_stream.destinations[destination_idx] = {
                "destination_id": destination_id,
                destination_name: destination_config,
            }

        # Once S3 is updated to an ExtendedS3 destination, both remain in
        # the destination. That means when one is updated, the other needs
        # to be updated as well. The problem is that they don't have the
        # same fields.
        if destination_name == "ExtendedS3":
            delivery_stream.destinations[destination_idx][
                "S3"
            ] = create_s3_destination_config(destination_config)
        elif destination_name == "S3" and "ExtendedS3" in destination:
            # NOTE(review): this replaces ExtendedS3 with only the keys it
            # shares with S3, so ExtendedS3-only settings are dropped; confirm
            # whether a merge (dict.update) was intended.
            destination["ExtendedS3"] = {
                k: v
                for k, v in destination["S3"].items()
                if k in destination["ExtendedS3"]
            }

        # Increment version number and update the timestamp.
        delivery_stream.version_id = str(int(current_delivery_stream_version_id) + 1)
        delivery_stream.last_update_timestamp = datetime.now(timezone.utc).isoformat()

        # Unimplemented: processing of the "S3BackupMode" parameter. Per the
        # documentation: "You can update a delivery stream to enable Amazon
        # S3 backup if it is disabled. If backup is enabled, you can't update
        # the delivery stream to disable it."

    def lookup_name_from_arn(self, arn):
        """Given an ARN, return the associated delivery stream.

        Despite the name, this returns the DeliveryStream object (or None if
        the name after the last "/" is unknown), not the name itself.
        """
        return self.delivery_streams.get(arn.split("/")[-1])

    def send_log_event(
        self,
        delivery_stream_arn,
        filter_name,
        log_group_name,
        log_stream_name,
        log_events,
    ):  # pylint: disable=too-many-arguments
        """Send log events to a S3 bucket after encoding and gzipping it.

        The payload is compact JSON, gzipped, then base64-encoded, and is
        written through put_s3_records() to the stream's first destination.
        """
        data = {
            "logEvents": log_events,
            "logGroup": log_group_name,
            "logStream": log_stream_name,
            "messageType": "DATA_MESSAGE",
            "owner": get_account_id(),
            "subscriptionFilters": [filter_name],
        }

        output = io.BytesIO()
        with GzipFile(fileobj=output, mode="w") as fhandle:
            fhandle.write(json.dumps(data, separators=(",", ":")).encode("utf-8"))
        gzipped_payload = b64encode(output.getvalue()).decode("utf-8")

        # NOTE(review): assumes the ARN names an existing stream whose first
        # destination has an "S3" entry; otherwise this raises
        # AttributeError/KeyError rather than a service exception.
        delivery_stream = self.lookup_name_from_arn(delivery_stream_arn)
        self.put_s3_records(
            delivery_stream.delivery_stream_name,
            delivery_stream.version_id,
            delivery_stream.destinations[0]["S3"],
            [{"Data": gzipped_payload}],
        )
def test_validate_tags():
    """Check that TaggingService.validate_tags() reports each kind of bad tag.

    Covers an invalid key character, an invalid value character, an
    over-length key (limit 128), an over-length value (limit 256), and a
    request that triggers all four problems across two tags at once.
    """
    tagging = TaggingService()

    ten_digits = "1234567890"
    oversized_key = ten_digits * 13  # 130 characters.
    oversized_value = ten_digits * 26  # 260 characters.

    key_length_msg = (
        "at 'tags.1.member.key' failed to satisfy constraint: Member must "
        "have length less than or equal to 128"
    )
    value_length_msg = (
        "at 'tags.1.member.value' failed to satisfy constraint: Member "
        "must have length less than or equal to 256"
    )

    # A key containing a character outside the allowed pattern.
    result = tagging.validate_tags([{"Key": "foo!", "Value": "bar"}])
    assert (
        "Value 'foo!' at 'tags.1.member.key' failed to satisfy constraint: "
        "Member must satisfy regular expression pattern"
    ) in result

    # A value containing a character outside the allowed pattern.
    result = tagging.validate_tags([{"Key": "foo", "Value": "bar!"}])
    assert (
        "Value 'bar!' at 'tags.1.member.value' failed to satisfy "
        "constraint: Member must satisfy regular expression pattern"
    ) in result

    # A key longer than the maximum length.
    result = tagging.validate_tags([{"Key": oversized_key, "Value": "foo"}])
    assert key_length_msg in result

    # A value longer than the maximum length.
    result = tagging.validate_tags([{"Key": "foo", "Value": oversized_value}])
    assert value_length_msg in result

    # Several problems at once must all be reported together.
    result = tagging.validate_tags(
        [
            {"Key": oversized_key, "Value": oversized_value},
            {"Key": "foo!", "Value": "bar!"},
        ]
    )
    expected_fragments = (
        "4 validation errors detected",
        key_length_msg,
        value_length_msg,
        "Value 'foo!' at 'tags.2.member.key' failed to satisfy constraint: "
        "Member must satisfy regular expression pattern",
        "Value 'bar!' at 'tags.2.member.value' failed to satisfy "
        "constraint: Member must satisfy regular expression pattern",
    )
    for fragment in expected_fragments:
        assert fragment in result