def test_delete_all_tags_for_resource():
    """Removing every tag from a resource leaves it with an empty tag list."""
    service = TaggingService()
    first_batch = [{"Key": "key_key", "Value": "value_value"}]
    second_batch = [{"Key": "key_key2", "Value": "value_value2"}]
    for batch in (first_batch, second_batch):
        service.tag_resource("arn", batch)

    service.delete_all_tags_for_resource("arn")

    assert service.list_tags_for_resource("arn") == {"Tags": []}
class DirectoryServiceBackend(BaseBackend):
    """Implementation of DirectoryService APIs."""

    def __init__(self, region_name=None):
        self.region_name = region_name
        # Maps directory ID -> Directory.
        self.directories = {}
        self.tagger = TaggingService()

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """List of dicts representing default VPC endpoints for this service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "ds"
        )

    @staticmethod
    def _verify_subnets(region, vpc_settings):
        """Verify subnets are valid, else raise an exception.

        If settings are valid, add AvailabilityZones to vpc_settings.

        Raises:
            InvalidParameterException: not exactly two subnet IDs, or an
                unknown subnet ID.
            ClientException: the subnets are not in two different
                availability zones, or the VPC ID is unknown.
        """
        if len(vpc_settings["SubnetIds"]) != 2:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones."
            )

        # Subnet IDs are checked before the VPC ID.  The Subnet IDs must
        # be valid and in different availability zones.
        try:
            subnets = ec2_backends[region].get_all_subnets(
                subnet_ids=vpc_settings["SubnetIds"]
            )
        except InvalidSubnetIdError as exc:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones."
            ) from exc

        zones = [subnet.availability_zone for subnet in subnets]
        # A repeated subnet ID can come back as a single subnet; treat that
        # like two subnets in one zone instead of indexing past the list end.
        if len(set(zones)) < 2:
            raise ClientException(
                "Invalid subnet ID(s). The two subnets must be in "
                "different Availability Zones."
            )

        vpcs = ec2_backends[region].describe_vpcs()
        if vpc_settings["VpcId"] not in [x.id for x in vpcs]:
            raise ClientException("Invalid VPC ID.")

        vpc_settings["AvailabilityZones"] = zones

    def _directory_counts(self):
        """Return the number of existing directories in each limit bucket.

        Buckets: "SimpleAD", "MicrosoftAD" (also counts "SharedMicrosoftAD")
        and "ConnectedAD" (directories of type "ADConnector").
        """
        counts = {"SimpleAD": 0, "MicrosoftAD": 0, "ConnectedAD": 0}
        for directory in self.directories.values():
            if directory.directory_type == "SimpleAD":
                counts["SimpleAD"] += 1
            elif directory.directory_type in ["MicrosoftAD", "SharedMicrosoftAD"]:
                counts["MicrosoftAD"] += 1
            elif directory.directory_type == "ADConnector":
                counts["ConnectedAD"] += 1
        return counts

    def _check_directory_limit(self, bucket, limit):
        """Raise if one more directory in ``bucket`` would exceed ``limit``.

        Counts are per bucket, matching get_directory_limits().  The check
        uses ``>=`` so that at most ``limit`` directories can ever exist.
        """
        if self._directory_counts()[bucket] >= limit:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{limit} directories may be created"
            )

    def _validate_new_directory_tags(self, tags):
        """Validate the tag list supplied while creating a directory."""
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

    def connect_directory(
        self,
        region,
        name,
        short_name,
        password,
        description,
        size,
        connect_settings,
        tags,
    ):  # pylint: disable=too-many-arguments
        """Create a fake AD Connector.

        Returns the new directory's ID.  ``tags`` may be None.
        """
        self._check_directory_limit(
            "ConnectedAD", Directory.CONNECTED_DIRECTORIES_LIMIT
        )

        validate_args(
            [
                ("password", password),
                ("size", size),
                ("name", name),
                ("description", description),
                ("shortName", short_name),
                (
                    "connectSettings.vpcSettings.subnetIds",
                    connect_settings["SubnetIds"],
                ),
                (
                    "connectSettings.customerUserName",
                    connect_settings["CustomerUserName"],
                ),
                ("connectSettings.customerDnsIps", connect_settings["CustomerDnsIps"]),
            ]
        )
        # ConnectSettings and VpcSettings both have a VpcId and Subnets.
        self._verify_subnets(region, connect_settings)

        tags = tags or []
        self._validate_new_directory_tags(tags)

        directory = Directory(
            region,
            name,
            password,
            "ADConnector",
            size=size,
            connect_settings=connect_settings,
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags)
        return directory.directory_id

    def create_directory(
        self, region, name, short_name, password, description, size, vpc_settings, tags
    ):  # pylint: disable=too-many-arguments
        """Create a fake Simple Ad Directory.

        Returns the new directory's ID.  ``tags`` may be None.
        """
        self._check_directory_limit(
            "SimpleAD", Directory.CLOUDONLY_DIRECTORIES_LIMIT
        )

        # botocore doesn't look for missing vpc_settings, but boto3 does.
        if not vpc_settings:
            raise InvalidParameterException("VpcSettings must be specified.")
        validate_args(
            [
                ("password", password),
                ("size", size),
                ("name", name),
                ("description", description),
                ("shortName", short_name),
                ("vpcSettings.subnetIds", vpc_settings["SubnetIds"]),
            ]
        )
        self._verify_subnets(region, vpc_settings)

        tags = tags or []
        self._validate_new_directory_tags(tags)

        directory = Directory(
            region,
            name,
            password,
            "SimpleAD",
            size=size,
            vpc_settings=vpc_settings,
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags)
        return directory.directory_id

    def _validate_directory_id(self, directory_id):
        """Raise an exception if the directory id is invalid or unknown."""
        # Validation of ID takes precedence over a check for its existence.
        validate_args([("directoryId", directory_id)])
        if directory_id not in self.directories:
            raise EntityDoesNotExistException(
                f"Directory {directory_id} does not exist"
            )

    def create_alias(self, directory_id, alias):
        """Create and assign an alias to a directory."""
        self._validate_directory_id(directory_id)

        # The default alias name is the same as the directory name.  Check
        # whether this directory was already given an alias.
        directory = self.directories[directory_id]
        if directory.alias != directory_id:
            raise InvalidParameterException(
                "The directory in the request already has an alias. That "
                "alias must be deleted before a new alias can be created."
            )

        # Is the alias already in use?
        if alias in [x.alias for x in self.directories.values()]:
            raise EntityAlreadyExistsException(f"Alias '{alias}' already exists.")
        validate_args([("alias", alias)])

        directory.update_alias(alias)
        return {"DirectoryId": directory_id, "Alias": alias}

    def create_microsoft_ad(
        self,
        region,
        name,
        short_name,
        password,
        description,
        vpc_settings,
        edition,
        tags,
    ):  # pylint: disable=too-many-arguments
        """Create a fake Microsoft Ad Directory.

        Returns the new directory's ID.  ``tags`` may be None.
        """
        self._check_directory_limit(
            "MicrosoftAD", Directory.CLOUDONLY_MICROSOFT_AD_LIMIT
        )

        # boto3 looks for missing vpc_settings for create_microsoft_ad().
        validate_args(
            [
                ("password", password),
                ("edition", edition),
                ("name", name),
                ("description", description),
                ("shortName", short_name),
                ("vpcSettings.subnetIds", vpc_settings["SubnetIds"]),
            ]
        )
        self._verify_subnets(region, vpc_settings)

        tags = tags or []
        self._validate_new_directory_tags(tags)

        directory = Directory(
            region,
            name,
            password,
            "MicrosoftAD",
            vpc_settings=vpc_settings,
            short_name=short_name,
            description=description,
            edition=edition,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags)
        return directory.directory_id

    def delete_directory(self, directory_id):
        """Delete directory with the matching ID, along with its tags."""
        self._validate_directory_id(directory_id)
        self.tagger.delete_all_tags_for_resource(directory_id)
        self.directories.pop(directory_id)
        return directory_id

    def disable_sso(self, directory_id, username=None, password=None):
        """Disable single-sign on for a directory."""
        self._validate_directory_id(directory_id)
        validate_args([("ssoPassword", password), ("userName", username)])
        directory = self.directories[directory_id]
        directory.enable_sso(False)

    def enable_sso(self, directory_id, username=None, password=None):
        """Enable single-sign on for a directory.

        Raises ClientException if the directory has no alias yet (its alias
        still equals its ID).
        """
        self._validate_directory_id(directory_id)
        validate_args([("ssoPassword", password), ("userName", username)])

        directory = self.directories[directory_id]
        if directory.alias == directory_id:
            raise ClientException(
                f"An alias is required before enabling SSO. DomainId={directory_id}"
            )
        directory.enable_sso(True)

    @paginate(pagination_model=PAGINATION_MODEL)
    def describe_directories(
        self, directory_ids=None, next_token=None, limit=0
    ):  # pylint: disable=unused-argument
        """Return info on all directories or directories with matching IDs."""
        for directory_id in directory_ids or self.directories:
            self._validate_directory_id(directory_id)

        directories = list(self.directories.values())
        if directory_ids:
            directories = [x for x in directories if x.directory_id in directory_ids]
        return sorted(directories, key=lambda x: x.launch_time)

    def get_directory_limits(self):
        """Return hard-coded limits for the directories, with current counts."""
        counts = self._directory_counts()
        return {
            "CloudOnlyDirectoriesLimit": Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyDirectoriesCurrentCount": counts["SimpleAD"],
            "CloudOnlyDirectoriesLimitReached": counts["SimpleAD"]
            >= Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyMicrosoftADLimit": Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "CloudOnlyMicrosoftADCurrentCount": counts["MicrosoftAD"],
            "CloudOnlyMicrosoftADLimitReached": counts["MicrosoftAD"]
            >= Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "ConnectedDirectoriesLimit": Directory.CONNECTED_DIRECTORIES_LIMIT,
            "ConnectedDirectoriesCurrentCount": counts["ConnectedAD"],
            "ConnectedDirectoriesLimitReached": counts["ConnectedAD"]
            >= Directory.CONNECTED_DIRECTORIES_LIMIT,
        }

    def add_tags_to_resource(self, resource_id, tags):
        """Add or overwrite one or more tags for specified directory.

        Raises TagLimitExceededException if the directory would end up with
        more than MAX_TAGS_PER_DIRECTORY distinct tag keys.
        """
        self._validate_directory_id(resource_id)
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)

        # Existing tags count toward the limit too; keys being overwritten
        # are only counted once.
        existing_keys = {
            tag.get("Key")
            for tag in self.tagger.list_tags_for_resource(resource_id)["Tags"]
        }
        new_keys = {tag.get("Key") for tag in tags}
        if (
            len(tags) > Directory.MAX_TAGS_PER_DIRECTORY
            or len(existing_keys | new_keys) > Directory.MAX_TAGS_PER_DIRECTORY
        ):
            raise TagLimitExceededException("Tag limit exceeded")
        self.tagger.tag_resource(resource_id, tags)

    def remove_tags_from_resource(self, resource_id, tag_keys):
        """Removes tags from a directory."""
        self._validate_directory_id(resource_id)
        self.tagger.untag_resource_using_names(resource_id, tag_keys)

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags_for_resource(
        self,
        resource_id,
        next_token=None,
        limit=None,
    ):  # pylint: disable=unused-argument
        """List all tags on a directory."""
        self._validate_directory_id(resource_id)
        return self.tagger.list_tags_for_resource(resource_id).get("Tags")
class KmsBackend(BaseBackend):
    """In-memory backend for KMS keys, aliases, crypto operations and tags."""

    def __init__(self):
        # Maps raw key ID -> Key.
        self.keys = {}
        # Maps raw key ID -> set of alias names pointing at it.  Lookups in
        # any_id_to_key_id() pass names starting with "alias/", so stored
        # names are expected to carry that prefix.
        self.key_to_aliases = defaultdict(set)
        self.tagger = TaggingService(keyName="TagKey", valueName="TagValue")

    @staticmethod
    def _key_not_found_error():
        """Build the NotFoundException raised for unknown keys."""
        return JsonRESTError(
            "NotFoundException",
            "The request was rejected because the specified entity or resource could not be found.",
        )

    def create_key(
        self, policy, key_usage, customer_master_key_spec, description, tags, region
    ):
        """Create, store and return a new Key, tagging it if tags are given."""
        key = Key(policy, key_usage, customer_master_key_spec, description, region)
        self.keys[key.id] = key
        if tags:
            self.tag_resource(key.id, tags)
        return key

    def update_key_description(self, key_id, description):
        key = self.keys[self.get_key_id(key_id)]
        key.description = description

    def delete_key(self, key_id):
        """Remove a key with its aliases and tags; return it, or None if unknown."""
        if key_id in self.keys:
            self.key_to_aliases.pop(key_id, None)
            self.tagger.delete_all_tags_for_resource(key_id)
            return self.keys.pop(key_id)
        return None

    def describe_key(self, key_id):
        """Return the Key for a raw key ID, key ARN, alias name or alias ARN.

        Raises KeyError if the key (or the key behind an alias) is unknown.
        """
        key_id = self.get_key_id(key_id)
        if "alias/" in key_id:
            # Works for both "alias/name" and "arn:...:alias/name".
            alias_name = "alias/" + key_id.split("alias/", 1)[1]
            key_id = self.get_key_id_from_alias(alias_name)
            if key_id is None:
                raise KeyError(alias_name)
        return self.keys[self.get_key_id(key_id)]

    def list_keys(self):
        return self.keys.values()

    @staticmethod
    def get_key_id(key_id):
        # Allow use of ARN as well as pure KeyId
        if key_id.startswith("arn:") and ":key/" in key_id:
            return key_id.split(":key/")[1]

        return key_id

    @staticmethod
    def get_alias_name(alias_name):
        # Allow use of ARN as well as alias name
        if alias_name.startswith("arn:") and ":alias/" in alias_name:
            return alias_name.split(":alias/")[1]

        return alias_name

    def any_id_to_key_id(self, key_id):
        """Go from any valid key ID to the raw key ID.

        Acceptable inputs:
        - raw key ID
        - key ARN
        - alias name
        - alias ARN

        Returns None when an alias does not resolve to any key.
        """
        if key_id.startswith("arn:") and ":alias/" in key_id:
            # get_alias_name() drops the "alias/" prefix, but alias lookups
            # below need it, so restore it for alias ARNs.
            key_id = "alias/" + self.get_alias_name(key_id)
        key_id = self.get_key_id(key_id)
        if key_id.startswith("alias/"):
            key_id = self.get_key_id_from_alias(key_id)
        return key_id

    def alias_exists(self, alias_name):
        return any(alias_name in aliases for aliases in self.key_to_aliases.values())

    def add_alias(self, target_key_id, alias_name):
        self.key_to_aliases[target_key_id].add(alias_name)

    def delete_alias(self, alias_name):
        """Delete the alias."""
        for aliases in self.key_to_aliases.values():
            if alias_name in aliases:
                aliases.remove(alias_name)

    def get_all_aliases(self):
        return self.key_to_aliases

    def get_key_id_from_alias(self, alias_name):
        """Return the raw key ID that ``alias_name`` points at, or None.

        Matching is exact (the previous substring match let "alias/foo" hit
        "alias/foobar").  A name without the "alias/" prefix is also tried
        with the prefix added.
        """
        candidates = {alias_name}
        if not alias_name.startswith("alias/"):
            candidates.add("alias/" + alias_name)
        for key_id, aliases in self.key_to_aliases.items():
            if candidates & aliases:
                return key_id
        return None

    def enable_key_rotation(self, key_id):
        self.keys[self.get_key_id(key_id)].key_rotation_status = True

    def disable_key_rotation(self, key_id):
        self.keys[self.get_key_id(key_id)].key_rotation_status = False

    def get_key_rotation_status(self, key_id):
        return self.keys[self.get_key_id(key_id)].key_rotation_status

    def put_key_policy(self, key_id, policy):
        self.keys[self.get_key_id(key_id)].policy = policy

    def get_key_policy(self, key_id):
        return self.keys[self.get_key_id(key_id)].policy

    def disable_key(self, key_id):
        key = self.keys[self.get_key_id(key_id)]
        key.enabled = False
        key.key_state = "Disabled"

    def enable_key(self, key_id):
        key = self.keys[self.get_key_id(key_id)]
        key.enabled = True
        key.key_state = "Enabled"

    def cancel_key_deletion(self, key_id):
        key = self.keys[self.get_key_id(key_id)]
        key.key_state = "Disabled"
        key.deletion_date = None

    def schedule_key_deletion(self, key_id, pending_window_in_days):
        """Mark a key PendingDeletion and return the deletion time (unix).

        Returns None, changing nothing, when the window is outside 7..30 days.
        """
        if 7 <= pending_window_in_days <= 30:
            key = self.keys[self.get_key_id(key_id)]
            key.enabled = False
            key.key_state = "PendingDeletion"
            key.deletion_date = datetime.now() + timedelta(days=pending_window_in_days)
            return unix_time(key.deletion_date)
        return None

    def encrypt(self, key_id, plaintext, encryption_context):
        key_id = self.any_id_to_key_id(key_id)

        ciphertext_blob = encrypt(
            master_keys=self.keys,
            key_id=key_id,
            plaintext=plaintext,
            encryption_context=encryption_context,
        )
        arn = self.keys[key_id].arn
        return ciphertext_blob, arn

    def decrypt(self, ciphertext_blob, encryption_context):
        plaintext, key_id = decrypt(
            master_keys=self.keys,
            ciphertext_blob=ciphertext_blob,
            encryption_context=encryption_context,
        )
        arn = self.keys[key_id].arn
        return plaintext, arn

    def re_encrypt(
        self,
        ciphertext_blob,
        source_encryption_context,
        destination_key_id,
        destination_encryption_context,
    ):
        destination_key_id = self.any_id_to_key_id(destination_key_id)

        plaintext, decrypting_arn = self.decrypt(
            ciphertext_blob=ciphertext_blob,
            encryption_context=source_encryption_context,
        )
        new_ciphertext_blob, encrypting_arn = self.encrypt(
            key_id=destination_key_id,
            plaintext=plaintext,
            encryption_context=destination_encryption_context,
        )
        return new_ciphertext_blob, decrypting_arn, encrypting_arn

    def generate_data_key(
        self, key_id, encryption_context, number_of_bytes, key_spec, grant_tokens
    ):
        key_id = self.any_id_to_key_id(key_id)

        if key_spec:
            # Note: Actual validation of key_spec is done in kms.responses
            plaintext_len = 16 if key_spec == "AES_128" else 32
        else:
            plaintext_len = number_of_bytes

        plaintext = os.urandom(plaintext_len)

        ciphertext_blob, arn = self.encrypt(
            key_id=key_id, plaintext=plaintext, encryption_context=encryption_context
        )

        return plaintext, ciphertext_blob, arn

    def list_resource_tags(self, key_id):
        """List tags on a key given its raw ID or key ARN."""
        key_id = self.get_key_id(key_id)
        if key_id in self.keys:
            return self.tagger.list_tags_for_resource(key_id)
        raise self._key_not_found_error()

    def tag_resource(self, key_id, tags):
        """Tag a key given its raw ID or key ARN."""
        key_id = self.get_key_id(key_id)
        if key_id in self.keys:
            self.tagger.tag_resource(key_id, tags)
            return {}
        raise self._key_not_found_error()

    def untag_resource(self, key_id, tag_names):
        """Remove tags by name from a key given its raw ID or key ARN."""
        key_id = self.get_key_id(key_id)
        if key_id in self.keys:
            self.tagger.untag_resource_using_names(key_id, tag_names)
            return {}
        raise self._key_not_found_error()
class ECRBackend(BaseBackend):
    """In-memory backend for ECR repositories, images, policies and scans."""

    def __init__(self, region_name):
        self.region_name = region_name
        self.registry_policy = None
        self.replication_config = {"rules": []}
        # Maps repository name -> Repository.
        self.repositories: Dict[str, Repository] = {}
        self.tagger = TaggingService(tag_name="tags")

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint service."""
        docker_endpoint = {
            "AcceptanceRequired": False,
            "AvailabilityZones": zones,
            "BaseEndpointDnsNames": [f"dkr.ecr.{service_region}.vpce.amazonaws.com"],
            "ManagesVpcEndpoints": False,
            "Owner": "amazon",
            "PrivateDnsName": f"*.dkr.ecr.{service_region}.amazonaws.com",
            "PrivateDnsNameVerificationState": "verified",
            "PrivateDnsNames": [
                {"PrivateDnsName": f"*.dkr.ecr.{service_region}.amazonaws.com"}
            ],
            "ServiceId": f"vpce-svc-{BaseBackend.vpce_random_number()}",
            "ServiceName": f"com.amazonaws.{service_region}.ecr.dkr",
            "ServiceType": [{"ServiceType": "Interface"}],
            "Tags": [],
            "VpcEndpointPolicySupported": True,
        }
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "api.ecr", special_service_name="ecr.api"
        ) + [docker_endpoint]

    def _get_repository(self, name, registry_id=None) -> Repository:
        """Return the named repository, raising if it is not in the registry."""
        repo = self.repositories.get(name)
        reg_id = registry_id or DEFAULT_REGISTRY_ID

        if not repo or repo.registry_id != reg_id:
            raise RepositoryNotFoundException(name, reg_id)
        return repo

    @staticmethod
    def _parse_resource_arn(resource_arn) -> EcrRepositoryArn:
        match = re.match(ECR_REPOSITORY_ARN_PATTERN, resource_arn)
        if not match:
            raise InvalidParameterException(
                "Invalid parameter at 'resourceArn' failed to satisfy constraint: "
                "'Invalid ARN'"
            )
        return EcrRepositoryArn(**match.groupdict())

    def describe_repositories(self, registry_id=None, repository_names=None):
        """
        maxResults and nextToken not implemented
        """
        if repository_names:
            for repository_name in repository_names:
                if repository_name not in self.repositories:
                    raise RepositoryNotFoundException(
                        repository_name, registry_id or DEFAULT_REGISTRY_ID
                    )

        repositories = []
        for repository in self.repositories.values():
            # If a registry_id was supplied, ensure this repository matches
            if registry_id and repository.registry_id != registry_id:
                continue
            # If a list of repository names was supplied, ensure this
            # repository is in that list
            if repository_names and repository.name not in repository_names:
                continue
            repositories.append(repository.response_object)
        return repositories

    def create_repository(
        self,
        repository_name,
        registry_id,
        encryption_config,
        image_scan_config,
        image_tag_mutablility,
        tags,
    ):
        """Create a repository and tag it; ``tags`` may be None."""
        if self.repositories.get(repository_name):
            raise RepositoryAlreadyExistsException(repository_name, DEFAULT_REGISTRY_ID)

        repository = Repository(
            region_name=self.region_name,
            repository_name=repository_name,
            registry_id=registry_id,
            encryption_config=encryption_config,
            image_scan_config=image_scan_config,
            image_tag_mutablility=image_tag_mutablility,
        )
        self.repositories[repository_name] = repository
        self.tagger.tag_resource(repository.arn, tags or [])
        return repository

    def delete_repository(self, repository_name, registry_id=None, force=False):
        repo = self._get_repository(repository_name, registry_id)

        if repo.images and not force:
            raise RepositoryNotEmptyException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID
            )

        self.tagger.delete_all_tags_for_resource(repo.arn)
        return self.repositories.pop(repository_name)

    def list_images(self, repository_name, registry_id=None):
        """
        maxResults and filtering not implemented
        """
        repository = self.repositories.get(repository_name)
        # Without a registry_id, any repository with this name matches.
        found = repository is not None and (
            not registry_id or repository.registry_id == registry_id
        )
        if not found:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID
            )
        return list(repository.images)

    def describe_images(self, repository_name, registry_id=None, image_ids=None):
        """Return the repository's images, or only those matching ``image_ids``.

        Always returns a list.  Requested images are de-duplicated while
        keeping request order, so the result is deterministic.
        """
        repository = self._get_repository(repository_name, registry_id)

        if image_ids:
            return list(
                dict.fromkeys(
                    repository._get_image(
                        image_id.get("imageTag"), image_id.get("imageDigest")
                    )
                    for image_id in image_ids
                )
            )
        return list(repository.images)

    def put_image(self, repository_name, image_manifest, image_tag):
        """Add an image, or re-tag the existing image with the same manifest."""
        if repository_name not in self.repositories:
            raise RepositoryNotFoundException(repository_name, DEFAULT_REGISTRY_ID)
        repository = self.repositories[repository_name]

        existing_images = [
            image
            for image in repository.images
            if image.response_object["imageManifest"] == image_manifest
        ]

        if not existing_images:
            # this image is not in ECR yet
            image = Image(image_tag, image_manifest, repository_name)
            repository.images.append(image)
            return image

        # update existing image
        existing_images[0].update_tag(image_tag)
        return existing_images[0]

    def batch_get_image(self, repository_name, registry_id=None, image_ids=None):
        """
        The parameter AcceptedMediaTypes has not yet been implemented
        """
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID
            )

        if not image_ids:
            raise ParamValidationError(
                msg='Missing required parameter in input: "imageIds"'
            )

        response = {"images": [], "failures": []}

        for image_id in image_ids:
            found = False
            for image in repository.images:
                # An image can carry several tags; match any of them rather
                # than only the most recently applied one.
                if (
                    "imageDigest" in image_id
                    and image.get_image_digest() == image_id["imageDigest"]
                ) or (
                    "imageTag" in image_id
                    and image_id["imageTag"] in image.image_tags
                ):
                    found = True
                    response["images"].append(image.response_batch_get_image)

            if not found:
                response["failures"].append(
                    {
                        "imageId": {"imageTag": image_id.get("imageTag", "null")},
                        "failureCode": "ImageNotFound",
                        "failureReason": "Requested image not found",
                    }
                )

        return response

    def batch_delete_image(self, repository_name, registry_id=None, image_ids=None):
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID
            )

        if not image_ids:
            raise ParamValidationError(
                msg='Missing required parameter in input: "imageIds"'
            )

        digest_pattern = re.compile(r"^[0-9a-zA-Z_+\.-]+:[0-9a-fA-F]{64}")
        response = {"imageIds": [], "failures": []}

        for image_id in image_ids:
            image_found = False

            # Is request missing both digest and tag?
            if "imageDigest" not in image_id and "imageTag" not in image_id:
                response["failures"].append(
                    {
                        "imageId": {},
                        "failureCode": "MissingDigestAndTag",
                        "failureReason": "Invalid request parameters: both tag and digest cannot be null",
                    }
                )
                continue

            # If we have a digest, is it valid?
            if "imageDigest" in image_id:
                if not digest_pattern.match(image_id.get("imageDigest")):
                    response["failures"].append(
                        {
                            "imageId": {
                                "imageDigest": image_id.get("imageDigest", "null")
                            },
                            "failureCode": "InvalidImageDigest",
                            "failureReason": "Invalid request parameters: image digest should satisfy the regex '[a-zA-Z0-9-_+.]+:[a-fA-F0-9]+'",
                        }
                    )
                    continue

            # NOTE(review): the branches below delete from repository.images
            # while enumerating it, which shifts indices and skips the next
            # image.  Kept as-is to preserve behaviour; worth revisiting.
            for num, image in enumerate(repository.images):

                # Search by matching both digest and tag
                if "imageDigest" in image_id and "imageTag" in image_id:
                    if (
                        image_id["imageDigest"] == image.get_image_digest()
                        and image_id["imageTag"] in image.image_tags
                    ):
                        image_found = True
                        for image_tag in reversed(image.image_tags):
                            repository.images[num].image_tag = image_tag
                            response["imageIds"].append(
                                image.response_batch_delete_image
                            )
                            repository.images[num].remove_tag(image_tag)
                        del repository.images[num]

                # Search by matching digest
                elif (
                    "imageDigest" in image_id
                    and image.get_image_digest() == image_id["imageDigest"]
                ):
                    image_found = True
                    for image_tag in reversed(image.image_tags):
                        repository.images[num].image_tag = image_tag
                        response["imageIds"].append(image.response_batch_delete_image)
                        repository.images[num].remove_tag(image_tag)
                    del repository.images[num]

                # Search by matching tag
                elif (
                    "imageTag" in image_id and image_id["imageTag"] in image.image_tags
                ):
                    image_found = True
                    repository.images[num].image_tag = image_id["imageTag"]
                    response["imageIds"].append(image.response_batch_delete_image)
                    if len(image.image_tags) > 1:
                        repository.images[num].remove_tag(image_id["imageTag"])
                    else:
                        repository.images.remove(image)

            if not image_found:
                failure_response = {
                    "imageId": {},
                    "failureCode": "ImageNotFound",
                    "failureReason": "Requested image not found",
                }

                if "imageDigest" in image_id:
                    failure_response["imageId"]["imageDigest"] = image_id.get(
                        "imageDigest", "null"
                    )

                if "imageTag" in image_id:
                    failure_response["imageId"]["imageTag"] = image_id.get(
                        "imageTag", "null"
                    )

                response["failures"].append(failure_response)

        return response

    def list_tags_for_resource(self, arn):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)

        return self.tagger.list_tags_for_resource(repo.arn)

    def tag_resource(self, arn, tags):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)
        self.tagger.tag_resource(repo.arn, tags)

        return {}

    def untag_resource(self, arn, tag_keys):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)
        self.tagger.untag_resource_using_names(repo.arn, tag_keys)

        return {}

    def put_image_tag_mutability(
        self, registry_id, repository_name, image_tag_mutability
    ):
        if image_tag_mutability not in ["IMMUTABLE", "MUTABLE"]:
            raise InvalidParameterException(
                "Invalid parameter at 'imageTagMutability' failed to satisfy constraint: "
                "'Member must satisfy enum value set: [IMMUTABLE, MUTABLE]'"
            )

        repo = self._get_repository(repository_name, registry_id)
        repo.update(image_tag_mutability=image_tag_mutability)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageTagMutability": repo.image_tag_mutability,
        }

    def put_image_scanning_configuration(
        self, registry_id, repository_name, image_scan_config
    ):
        repo = self._get_repository(repository_name, registry_id)
        repo.update(image_scan_config=image_scan_config)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageScanningConfiguration": repo.image_scanning_configuration,
        }

    def set_repository_policy(self, registry_id, repository_name, policy_text):
        repo = self._get_repository(repository_name, registry_id)

        try:
            iam_policy_document_validator = IAMPolicyDocumentValidator(policy_text)
            # the repository policy can be defined without a resource field
            iam_policy_document_validator._validate_resource_exist = lambda: None
            # the repository policy can have the old version 2008-10-17
            iam_policy_document_validator._validate_version = lambda: None
            iam_policy_document_validator.validate()
        except MalformedPolicyDocument as exc:
            raise InvalidParameterException(
                "Invalid parameter at 'PolicyText' failed to satisfy constraint: "
                "'Invalid repository policy provided'"
            ) from exc

        repo.policy = policy_text

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": repo.policy,
        }

    def get_repository_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)

        if not repo.policy:
            raise RepositoryPolicyNotFoundException(repository_name, repo.registry_id)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": repo.policy,
        }

    def delete_repository_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)
        policy = repo.policy

        if not policy:
            raise RepositoryPolicyNotFoundException(repository_name, repo.registry_id)

        repo.policy = None

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": policy,
        }

    def put_lifecycle_policy(self, registry_id, repository_name, lifecycle_policy_text):
        repo = self._get_repository(repository_name, registry_id)

        validator = EcrLifecyclePolicyValidator(lifecycle_policy_text)
        validator.validate()

        repo.lifecycle_policy = lifecycle_policy_text

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "lifecyclePolicyText": repo.lifecycle_policy,
        }

    def get_lifecycle_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)

        if not repo.lifecycle_policy:
            raise LifecyclePolicyNotFoundException(repository_name, repo.registry_id)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "lifecyclePolicyText": repo.lifecycle_policy,
            "lastEvaluatedAt": iso_8601_datetime_without_milliseconds(
                datetime.utcnow()
            ),
        }

    def delete_lifecycle_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)
        policy = repo.lifecycle_policy

        if not policy:
            raise LifecyclePolicyNotFoundException(repository_name, repo.registry_id)

        repo.lifecycle_policy = None

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "lifecyclePolicyText": policy,
            "lastEvaluatedAt": iso_8601_datetime_without_milliseconds(
                datetime.utcnow()
            ),
        }

    def _validate_registry_policy_action(self, policy_text):
        """Raise MalformedPolicyDocument if a statement uses a disallowed action."""
        # only CreateRepository & ReplicateImage actions are allowed
        VALID_ACTIONS = {"ecr:CreateRepository", "ecr:ReplicateImage"}

        policy = json.loads(policy_text)
        for statement in policy["Statement"]:
            action = statement["Action"]
            if isinstance(action, str):
                action = [action]
            if set(action) - VALID_ACTIONS:
                raise MalformedPolicyDocument()

    def put_registry_policy(self, policy_text):
        try:
            iam_policy_document_validator = IAMPolicyDocumentValidator(policy_text)
            iam_policy_document_validator.validate()

            self._validate_registry_policy_action(policy_text)
        except MalformedPolicyDocument as exc:
            raise InvalidParameterException(
                "Invalid parameter at 'PolicyText' failed to satisfy constraint: "
                "'Invalid registry policy provided'"
            ) from exc

        self.registry_policy = policy_text

        return {
            "registryId": get_account_id(),
            "policyText": policy_text,
        }

    def get_registry_policy(self):
        if not self.registry_policy:
            raise RegistryPolicyNotFoundException(get_account_id())

        return {
            "registryId": get_account_id(),
            "policyText": self.registry_policy,
        }

    def delete_registry_policy(self):
        policy = self.registry_policy
        if not policy:
            raise RegistryPolicyNotFoundException(get_account_id())

        self.registry_policy = None

        return {
            "registryId": get_account_id(),
            "policyText": policy,
        }

    def start_image_scan(self, registry_id, repository_name, image_id):
        repo = self._get_repository(repository_name, registry_id)

        image = repo._get_image(image_id.get("imageTag"), image_id.get("imageDigest"))

        # scanning an image is only allowed once per day
        if image.last_scan and image.last_scan.date() == datetime.today().date():
            raise LimitExceededException()

        image.last_scan = datetime.today()

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageId": {
                "imageDigest": image.image_digest,
                "imageTag": image.image_tag,
            },
            "imageScanStatus": {"status": "IN_PROGRESS"},
        }

    def describe_image_scan_findings(self, registry_id, repository_name, image_id):
        repo = self._get_repository(repository_name, registry_id)

        image = repo._get_image(image_id.get("imageTag"), image_id.get("imageDigest"))

        if not image.last_scan:
            image_id_rep = "{{imageDigest:'{0}', imageTag:'{1}'}}".format(
                image_id.get("imageDigest") or "null",
                image_id.get("imageTag") or "null",
            )
            raise ScanNotFoundException(
                image_id=image_id_rep,
                repository_name=repository_name,
                registry_id=repo.registry_id,
            )

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageId": {
                "imageDigest": image.image_digest,
                "imageTag": image.image_tag,
            },
            "imageScanStatus": {
                "status": "COMPLETE",
                "description": "The scan was completed successfully.",
            },
            "imageScanFindings": {
                "imageScanCompletedAt": iso_8601_datetime_without_milliseconds(
                    image.last_scan
                ),
                "vulnerabilitySourceUpdatedAt": iso_8601_datetime_without_milliseconds(
                    datetime.utcnow()
                ),
                "findings": [
                    {
                        "name": "CVE-9999-9999",
                        "uri": "https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-9999-9999",
                        "severity": "HIGH",
                        "attributes": [
                            {"key": "package_version", "value": "9.9.9"},
                            {"key": "package_name", "value": "moto_fake"},
                            {
                                "key": "CVSS2_VECTOR",
                                "value": "AV:N/AC:L/Au:N/C:P/I:P/A:P",
                            },
                            {"key": "CVSS2_SCORE", "value": "7.5"},
                        ],
                    }
                ],
                "findingSeverityCounts": {"HIGH": 1},
            },
        }

    def put_replication_configuration(self, replication_config):
        rules = replication_config["rules"]
        if len(rules) > 1:
            raise ValidationException("This feature is disabled")

        if len(rules) == 1:
            for dest in rules[0]["destinations"]:
                if (
                    dest["region"] == self.region_name
                    and dest["registryId"] == DEFAULT_REGISTRY_ID
                ):
                    raise InvalidParameterException(
                        "Invalid parameter at 'replicationConfiguration' failed to satisfy constraint: "
                        "'Replication destination cannot be the same as the source registry'"
                    )

        self.replication_config = replication_config

        return {"replicationConfiguration": replication_config}

    def describe_registry(self):
        return {
            "registryId": DEFAULT_REGISTRY_ID,
            "replicationConfiguration": self.replication_config,
        }
class EventsBackend(BaseBackend):
    """In-memory EventBridge (CloudWatch Events) backend for one region.

    Stores rules, event buses, event sources, archives and replays, and uses
    a ``TaggingService`` to hold rule tags.

    NOTE(review): another ``EventsBackend`` class is defined later in this
    file; the later definition rebinds the name when the module is imported.
    Confirm which one is intended.
    """

    # Accepted formats for put_permission's Principal and StatementId.
    ACCOUNT_ID = re.compile(r"^(\d{1,12}|\*)$")
    STATEMENT_ID = re.compile(r"^[a-zA-Z0-9-_]{1,64}$")

    def __init__(self, region_name):
        self.rules = {}
        # This array tracks the order in which the rules have been added, since
        # 2.6 doesn't have OrderedDicts.
        self.rules_order = []
        # Pagination token -> index at which the next page starts.
        self.next_tokens = {}
        self.region_name = region_name
        self.event_buses = {}
        self.event_sources = {}
        self.archives = {}
        self.replays = {}
        self.tagger = TaggingService()

        self._add_default_event_bus()

    def reset(self):
        """Re-initialize all state, keeping only the region name."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def _add_default_event_bus(self):
        """Create the event bus named ``default``."""
        self.event_buses["default"] = EventBus(self.region_name, "default")

    def _get_rule_by_index(self, i):
        """Return the ``i``-th rule in insertion order."""
        return self.rules.get(self.rules_order[i])

    def _gen_next_token(self, index):
        """Create an opaque pagination token that resumes at ``index``.

        The token is stored in ``self.next_tokens`` until it is consumed by
        ``_process_token_and_limits``.
        """
        # ``bytes`` has no ``.encode("base64")`` on Python 3 (that was a
        # Python 2 str codec), so encode through the base64 module instead.
        import base64

        token = base64.b64encode(os.urandom(128)).decode("utf-8")
        self.next_tokens[token] = index
        return token

    def _process_token_and_limits(self, array_len, next_token=None, limit=None):
        """Compute the page bounds for a list of ``array_len`` items.

        :param array_len: total number of items being paginated.
        :param next_token: token from a previous page; unknown tokens start
            from index 0.
        :param limit: maximum page size, or ``None`` for no limit.
        :returns: ``(start_index, end_index, new_next_token)``. The new token
            is ``None`` unless items remain after ``end_index``.
        """
        start_index = 0
        end_index = array_len
        new_next_token = None

        if next_token:
            start_index = self.next_tokens.pop(next_token, 0)

        if limit is not None:
            new_end_index = start_index + int(limit)
            if new_end_index < end_index:
                end_index = new_end_index
                new_next_token = self._gen_next_token(end_index)

        return start_index, end_index, new_next_token

    def _get_event_bus(self, name):
        """Return an event bus by name or ARN (the last ``/`` segment is used).

        :raises ResourceNotFoundException: if the bus does not exist.
        """
        event_bus_name = name.split("/")[-1]

        event_bus = self.event_buses.get(event_bus_name)
        if not event_bus:
            raise ResourceNotFoundException(
                "Event bus {} does not exist.".format(event_bus_name))

        return event_bus

    def _get_replay(self, name):
        """Return a replay by name.

        :raises ResourceNotFoundException: if the replay does not exist.
        """
        replay = self.replays.get(name)
        if not replay:
            raise ResourceNotFoundException(
                "Replay {} does not exist.".format(name))

        return replay

    def delete_rule(self, name):
        """Delete rule ``name`` together with its tags.

        :returns: ``True`` if the rule existed, ``False`` if it did not.
            Deleting an unknown rule is a no-op rather than an error.
        """
        rule = self.rules.get(name)
        if rule is None:
            return False

        self.rules_order.remove(name)
        if self.tagger.has_tags(rule.arn):
            self.tagger.delete_all_tags_for_resource(rule.arn)
        del self.rules[name]
        return True

    def describe_rule(self, name):
        """Return the rule named ``name``, or ``None`` if unknown."""
        return self.rules.get(name)

    def disable_rule(self, name):
        """Disable rule ``name``; return whether the rule exists."""
        if name in self.rules:
            self.rules[name].disable()
            return True

        return False

    def enable_rule(self, name):
        """Enable rule ``name``; return whether the rule exists."""
        if name in self.rules:
            self.rules[name].enable()
            return True

        return False

    def list_rule_names_by_target(self, target_arn, next_token=None, limit=None):
        """Return names of rules in the current page that target ``target_arn``.

        Pagination is applied over all rules first; matching happens within
        the selected page.
        """
        matching_rules = []
        return_obj = {}

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(self.rules), next_token, limit)

        for i in range(start_index, end_index):
            rule = self._get_rule_by_index(i)
            for target in rule.targets:
                if target["Arn"] == target_arn:
                    matching_rules.append(rule.name)

        return_obj["RuleNames"] = matching_rules
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def list_rules(self, prefix=None, next_token=None, limit=None):
        """Return rules in the current page whose names start with ``prefix``.

        ``prefix`` is compared literally (it is not a regular expression);
        ``None`` matches every rule.
        """
        matching_rules = []
        return_obj = {}

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(self.rules), next_token, limit)

        for i in range(start_index, end_index):
            rule = self._get_rule_by_index(i)
            if prefix is None or rule.name.startswith(prefix):
                matching_rules.append(rule)

        return_obj["Rules"] = matching_rules
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def list_targets_by_rule(self, rule, next_token=None, limit=None):
        """Return one page of targets for the rule named ``rule``."""
        # We'll let a KeyError exception be thrown for response to handle if
        # rule doesn't exist.
        rule = self.rules[rule]

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(rule.targets), next_token, limit)

        returned_targets = []
        return_obj = {}

        for i in range(start_index, end_index):
            returned_targets.append(rule.targets[i])

        return_obj["Targets"] = returned_targets
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def update_rule(self, rule, **kwargs):
        """Overwrite rule fields with any truthy values present in ``kwargs``."""
        rule.event_pattern = kwargs.get("EventPattern") or rule.event_pattern
        rule.schedule_exp = kwargs.get("ScheduleExpression") or rule.schedule_exp
        rule.state = kwargs.get("State") or rule.state
        rule.description = kwargs.get("Description") or rule.description
        rule.role_arn = kwargs.get("RoleArn") or rule.role_arn
        rule.event_bus_name = kwargs.get("EventBusName") or rule.event_bus_name

    def put_rule(self, name, **kwargs):
        """Create rule ``name`` or update it in place if it already exists.

        :raises ValidationException: if a ScheduleExpression is given and
            EventBusName is not ``"default"``.
        """
        if kwargs.get("ScheduleExpression") and kwargs.get("EventBusName") != "default":
            raise ValidationException(
                "ScheduleExpression is supported only on the default event bus."
            )

        if name in self.rules:
            self.update_rule(self.rules[name], **kwargs)
            new_rule = self.rules[name]
        else:
            new_rule = Rule(name, self.region_name, **kwargs)
            self.rules[new_rule.name] = new_rule
            self.rules_order.append(new_rule.name)
        return new_rule

    def put_targets(self, name, event_bus_name, targets):
        """Attach ``targets`` to rule ``name`` after light validation.

        :raises ValidationException: for a malformed target ARN, or an SQS
            FIFO target without SqsParameters.
        :raises ResourceNotFoundException: if the rule does not exist.
        """
        # super simple ARN check
        invalid_arn = next(
            (
                target["Arn"]
                for target in targets
                if not re.match(r"arn:[\d\w:\-/]*", target["Arn"])
            ),
            None,
        )
        if invalid_arn:
            raise ValidationException(
                "Parameter {} is not valid. "
                "Reason: Provided Arn is not in correct format.".format(invalid_arn))

        for target in targets:
            arn = target["Arn"]

            if (":sqs:" in arn and arn.endswith(".fifo")
                    and not target.get("SqsParameters")):
                raise ValidationException(
                    "Parameter(s) SqsParameters must be specified for target: {}."
                    .format(target["Id"]))

        rule = self.rules.get(name)

        if not rule:
            raise ResourceNotFoundException(
                "Rule {0} does not exist on EventBus {1}.".format(name, event_bus_name))

        rule.put_targets(targets)

    def put_events(self, events):
        """Validate events and dispatch the valid ones to every rule.

        :returns: one entry per input event, either ``{"EventId": ...}`` or
            an ``ErrorCode``/``ErrorMessage`` pair.
        :raises ValidationException: if more than 10 events are supplied.
        """
        num_events = len(events)

        if num_events > 10:
            # the exact error text is longer, the Value list consists of all the put events
            raise ValidationException(
                "1 validation error detected: "
                "Value '[PutEventsRequestEntry]' at 'entries' failed to satisfy constraint: "
                "Member must have length less than or equal to 10")

        entries = []
        for event in events:
            if "Source" not in event:
                entries.append({
                    "ErrorCode": "InvalidArgument",
                    "ErrorMessage": "Parameter Source is not valid. Reason: Source is a required argument.",
                })
            elif "DetailType" not in event:
                entries.append({
                    "ErrorCode": "InvalidArgument",
                    "ErrorMessage": "Parameter DetailType is not valid. Reason: DetailType is a required argument.",
                })
            elif "Detail" not in event:
                entries.append({
                    "ErrorCode": "InvalidArgument",
                    "ErrorMessage": "Parameter Detail is not valid. Reason: Detail is a required argument.",
                })
            else:
                try:
                    json.loads(event["Detail"])
                except ValueError:  # json.JSONDecodeError exists since Python 3.5
                    entries.append({
                        "ErrorCode": "MalformedDetail",
                        "ErrorMessage": "Detail is malformed.",
                    })
                    continue

                event_id = str(uuid4())
                entries.append({"EventId": event_id})

                # if 'EventBusName' is not especially set, it will be sent to the default one
                event_bus_name = event.get("EventBusName", "default")

                # Each rule gets its own freshly parsed "detail" object.
                for rule in self.rules.values():
                    rule.send_to_targets(
                        event_bus_name,
                        {
                            "version": "0",
                            "id": event_id,
                            "detail-type": event["DetailType"],
                            "source": event["Source"],
                            "account": ACCOUNT_ID,
                            "time": event.get("Time", unix_time(datetime.utcnow())),
                            "region": self.region_name,
                            "resources": event.get("Resources", []),
                            "detail": json.loads(event["Detail"]),
                        },
                    )

        return entries

    def remove_targets(self, name, event_bus_name, ids):
        """Remove targets ``ids`` from rule ``name``.

        :raises ResourceNotFoundException: if the rule does not exist.
        """
        rule = self.rules.get(name)

        if not rule:
            raise ResourceNotFoundException(
                "Rule {0} does not exist on EventBus {1}.".format(name, event_bus_name))

        rule.remove_targets(ids)

    def test_event_pattern(self):
        raise NotImplementedError()

    def put_permission(self, event_bus_name, action, principal, statement_id):
        """Add a PutEvents permission statement to an event bus.

        ``event_bus_name`` defaults to ``"default"`` when empty.
        """
        if not event_bus_name:
            event_bus_name = "default"

        event_bus = self.describe_event_bus(event_bus_name)

        if action is None or action != "events:PutEvents":
            raise JsonRESTError(
                "ValidationException",
                "Provided value in parameter 'action' is not supported.",
            )

        if principal is None or self.ACCOUNT_ID.match(principal) is None:
            raise JsonRESTError("InvalidParameterValue",
                                r"Principal must match ^(\d{1,12}|\*)$")

        if statement_id is None or self.STATEMENT_ID.match(statement_id) is None:
            raise JsonRESTError(
                "InvalidParameterValue",
                r"StatementId must match ^[a-zA-Z0-9-_]{1,64}$")

        event_bus._permissions[statement_id] = {
            "Action": action,
            "Principal": principal,
        }

    def remove_permission(self, event_bus_name, statement_id):
        """Remove permission statement ``statement_id`` from an event bus.

        :raises JsonRESTError: ResourceNotFoundException if the bus has no
            permissions or the statement id is unknown.
        """
        if not event_bus_name:
            event_bus_name = "default"

        event_bus = self.describe_event_bus(event_bus_name)

        if not event_bus._permissions:
            raise JsonRESTError("ResourceNotFoundException",
                                "EventBus does not have a policy.")

        if not event_bus._permissions.pop(statement_id, None):
            raise JsonRESTError(
                "ResourceNotFoundException",
                "Statement with the provided id does not exist.",
            )

    def describe_event_bus(self, name):
        """Return the event bus by name or ARN; empty names mean ``default``."""
        if not name:
            name = "default"

        event_bus = self._get_event_bus(name)

        return event_bus

    def create_event_bus(self, name, event_source_name=None):
        """Create and return a new event bus.

        :raises JsonRESTError: if the bus already exists, the name contains
            ``/`` without an event source, or the event source is unknown.
        """
        if name in self.event_buses:
            raise JsonRESTError(
                "ResourceAlreadyExistsException",
                "Event bus {} already exists.".format(name),
            )

        if not event_source_name and "/" in name:
            raise JsonRESTError("ValidationException",
                                "Event bus name must not contain '/'.")

        if event_source_name and event_source_name not in self.event_sources:
            raise JsonRESTError(
                "ResourceNotFoundException",
                "Event source {} does not exist.".format(event_source_name),
            )

        self.event_buses[name] = EventBus(self.region_name, name)

        return self.event_buses[name]

    def list_event_buses(self, name_prefix):
        """Return all event buses, optionally filtered by name prefix."""
        if name_prefix:
            return [
                event_bus for event_bus in self.event_buses.values()
                if event_bus.name.startswith(name_prefix)
            ]

        return list(self.event_buses.values())

    def delete_event_bus(self, name):
        """Delete an event bus; the ``default`` bus cannot be deleted."""
        if name == "default":
            raise JsonRESTError("ValidationException",
                                "Cannot delete event bus default.")

        self.event_buses.pop(name, None)

    def list_tags_for_resource(self, arn):
        """Return tags of the rule whose name is the last ``/`` segment of ``arn``."""
        name = arn.split("/")[-1]
        if name in self.rules:
            return self.tagger.list_tags_for_resource(self.rules[name].arn)
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def tag_resource(self, arn, tags):
        """Add ``tags`` to the rule identified by ``arn``."""
        name = arn.split("/")[-1]
        if name in self.rules:
            self.tagger.tag_resource(self.rules[name].arn, tags)
            return {}
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def untag_resource(self, arn, tag_names):
        """Remove tag keys ``tag_names`` from the rule identified by ``arn``."""
        name = arn.split("/")[-1]
        if name in self.rules:
            self.tagger.untag_resource_using_names(self.rules[name].arn, tag_names)
            return {}
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def create_archive(self, name, source_arn, description, event_pattern, retention):
        """Create an archive plus the managed rule/target that feeds it.

        :raises ValidationException: if ``name`` is longer than 48 characters.
        :raises ResourceNotFoundException: if the source event bus is unknown.
        :raises ResourceAlreadyExistsException: if the archive already exists.
        """
        if len(name) > 48:
            raise ValidationException(
                " 1 validation error detected: "
                "Value '{}' at 'archiveName' failed to satisfy constraint: "
                "Member must have length less than or equal to 48".format(name))

        event_bus = self._get_event_bus(source_arn)

        if name in self.archives:
            raise ResourceAlreadyExistsException(
                "Archive {} already exists.".format(name))

        archive = Archive(self.region_name, name, source_arn, description,
                          event_pattern, retention)

        # Exclude replayed events so they are not archived a second time.
        rule_event_pattern = json.loads(event_pattern or "{}")
        rule_event_pattern["replay-name"] = [{"exists": False}]

        rule = self.put_rule(
            "Events-Archive-{}".format(name),
            **{
                "EventPattern": json.dumps(rule_event_pattern),
                "EventBusName": event_bus.name,
                "ManagedBy": "prod.vhs.events.aws.internal",
            }
        )
        self.put_targets(
            rule.name,
            rule.event_bus_name,
            [{
                "Id": rule.name,
                "Arn": "arn:aws:events:{}:::".format(self.region_name),
                "InputTransformer": {
                    "InputPathsMap": {},
                    "InputTemplate": json.dumps({
                        "archive-arn": "{0}:{1}".format(archive.arn, archive.uuid),
                        "event": "<aws.events.event.json>",
                        "ingestion-time": "<aws.events.event.ingestion-time>",
                    }),
                },
            }],
        )

        self.archives[name] = archive

        return archive

    def describe_archive(self, name):
        """Return the full description of archive ``name``."""
        archive = self.archives.get(name)

        if not archive:
            raise ResourceNotFoundException(
                "Archive {} does not exist.".format(name))

        return archive.describe()

    def list_archives(self, name_prefix, source_arn, state):
        """Return short descriptions of archives matching at most one filter."""
        if [name_prefix, source_arn, state].count(None) < 2:
            raise ValidationException(
                "At most one filter is allowed for ListArchives. "
                "Use either : State, EventSourceArn, or NamePrefix.")

        if state and state not in Archive.VALID_STATES:
            raise ValidationException(
                "1 validation error detected: "
                "Value '{0}' at 'state' failed to satisfy constraint: "
                "Member must satisfy enum value set: "
                "[{1}]".format(state, ", ".join(Archive.VALID_STATES)))

        if [name_prefix, source_arn, state].count(None) == 3:
            return [archive.describe_short() for archive in self.archives.values()]

        result = []

        for archive in self.archives.values():
            if name_prefix and archive.name.startswith(name_prefix):
                result.append(archive.describe_short())
            elif source_arn and archive.source_arn == source_arn:
                result.append(archive.describe_short())
            elif state and archive.state == state:
                result.append(archive.describe_short())

        return result

    def update_archive(self, name, description, event_pattern, retention):
        """Update archive ``name`` and return its ARN, creation time and state."""
        archive = self.archives.get(name)

        if not archive:
            raise ResourceNotFoundException(
                "Archive {} does not exist.".format(name))

        archive.update(description, event_pattern, retention)

        return {
            "ArchiveArn": archive.arn,
            "CreationTime": archive.creation_time,
            "State": archive.state,
        }

    def delete_archive(self, name):
        """Delete archive ``name``."""
        archive = self.archives.get(name)

        if not archive:
            raise ResourceNotFoundException(
                "Archive {} does not exist.".format(name))

        archive.delete(self.region_name)

    def start_replay(self, name, description, source_arn, start_time, end_time,
                     destination):
        """Validate and synchronously run a replay from an archive.

        :raises ValidationException: for a non-event-bus destination, an
            unknown archive, a cross-bus replay, or start_time > end_time.
        :raises ResourceAlreadyExistsException: if the replay name is taken.
        """
        event_bus_arn = destination["Arn"]
        event_bus_arn_pattern = r"^arn:aws:events:[a-zA-Z0-9-]+:\d{12}:event-bus/"
        if not re.match(event_bus_arn_pattern, event_bus_arn):
            raise ValidationException(
                "Parameter Destination.Arn is not valid. "
                "Reason: Must contain an event bus ARN.")

        self._get_event_bus(event_bus_arn)

        archive_name = source_arn.split("/")[-1]
        archive = self.archives.get(archive_name)
        if not archive:
            raise ValidationException(
                "Parameter EventSourceArn is not valid. "
                "Reason: Archive {} does not exist.".format(archive_name))

        if event_bus_arn != archive.source_arn:
            raise ValidationException(
                "Parameter Destination.Arn is not valid. "
                "Reason: Cross event bus replay is not permitted.")

        if start_time > end_time:
            raise ValidationException(
                "Parameter EventEndTime is not valid. "
                "Reason: EventStartTime must be before EventEndTime.")

        if name in self.replays:
            raise ResourceAlreadyExistsException(
                "Replay {} already exists.".format(name))

        replay = Replay(
            self.region_name,
            name,
            description,
            source_arn,
            start_time,
            end_time,
            destination,
        )

        self.replays[name] = replay

        replay.replay_events(archive)

        return {
            "ReplayArn": replay.arn,
            "ReplayStartTime": replay.start_time,
            "State": ReplayState.STARTING.value,  # the replay will be done before returning the response
        }

    def describe_replay(self, name):
        """Return the full description of replay ``name``."""
        replay = self._get_replay(name)

        return replay.describe()

    def list_replays(self, name_prefix, source_arn, state):
        """Return short descriptions of replays matching at most one filter."""
        if [name_prefix, source_arn, state].count(None) < 2:
            raise ValidationException(
                "At most one filter is allowed for ListReplays. "
                "Use either : State, EventSourceArn, or NamePrefix.")

        valid_states = sorted([item.value for item in ReplayState])
        if state and state not in valid_states:
            raise ValidationException(
                "1 validation error detected: "
                "Value '{0}' at 'state' failed to satisfy constraint: "
                "Member must satisfy enum value set: "
                "[{1}]".format(state, ", ".join(valid_states)))

        if [name_prefix, source_arn, state].count(None) == 3:
            return [replay.describe_short() for replay in self.replays.values()]

        result = []

        for replay in self.replays.values():
            if name_prefix and replay.name.startswith(name_prefix):
                result.append(replay.describe_short())
            elif source_arn and replay.source_arn == source_arn:
                result.append(replay.describe_short())
            elif state and replay.state == state:
                result.append(replay.describe_short())

        return result

    def cancel_replay(self, name):
        """Mark replay ``name`` as cancelled and report ``CANCELLING``."""
        replay = self._get_replay(name)

        # replays in the state 'COMPLETED' can't be canceled,
        # but the implementation is done synchronously,
        # so they are done right after the start
        if replay.state not in [
                ReplayState.STARTING,
                ReplayState.RUNNING,
                ReplayState.COMPLETED,
        ]:
            raise IllegalStatusException(
                "Replay {} is not in a valid state for this operation.".format(name))

        replay.state = ReplayState.CANCELLED

        return {"ReplayArn": replay.arn, "State": ReplayState.CANCELLING.value}
class EventsBackend(BaseBackend):
    """In-memory EventBridge (CloudWatch Events) backend for one region.

    Stores rules, event buses and event sources, and uses a
    ``TaggingService`` for rule tags. This variant has no archive or replay
    support.

    NOTE(review): this class has the same name as an ``EventsBackend``
    defined earlier in this file; this later definition rebinds the name at
    import time. Confirm which one is intended.
    """

    # Accepted formats for put_permission's Principal and StatementId.
    ACCOUNT_ID = re.compile(r"^(\d{1,12}|\*)$")
    STATEMENT_ID = re.compile(r"^[a-zA-Z0-9-_]{1,64}$")

    def __init__(self, region_name):
        self.rules = {}
        # This array tracks the order in which the rules have been added, since
        # 2.6 doesn't have OrderedDicts.
        self.rules_order = []
        # Pagination token -> index at which the next page starts.
        self.next_tokens = {}
        self.region_name = region_name
        self.event_buses = {}
        self.event_sources = {}
        self.tagger = TaggingService()

        self._add_default_event_bus()

    def reset(self):
        """Re-initialize all state, keeping only the region name."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def _add_default_event_bus(self):
        """Create the event bus named ``default``."""
        self.event_buses["default"] = EventBus(self.region_name, "default")

    def _get_rule_by_index(self, i):
        """Return the ``i``-th rule in insertion order."""
        return self.rules.get(self.rules_order[i])

    def _gen_next_token(self, index):
        """Create an opaque pagination token that resumes at ``index``.

        The token is stored in ``self.next_tokens`` until it is consumed by
        ``_process_token_and_limits``.
        """
        # ``bytes`` has no ``.encode("base64")`` on Python 3 (that was a
        # Python 2 str codec), so encode through the base64 module instead.
        import base64

        token = base64.b64encode(os.urandom(128)).decode("utf-8")
        self.next_tokens[token] = index
        return token

    def _process_token_and_limits(self, array_len, next_token=None, limit=None):
        """Compute the page bounds for a list of ``array_len`` items.

        :param array_len: total number of items being paginated.
        :param next_token: token from a previous page; unknown tokens start
            from index 0.
        :param limit: maximum page size, or ``None`` for no limit.
        :returns: ``(start_index, end_index, new_next_token)``. The new token
            is ``None`` unless items remain after ``end_index``.
        """
        start_index = 0
        end_index = array_len
        new_next_token = None

        if next_token:
            start_index = self.next_tokens.pop(next_token, 0)

        if limit is not None:
            new_end_index = start_index + int(limit)
            if new_end_index < end_index:
                end_index = new_end_index
                new_next_token = self._gen_next_token(end_index)

        return start_index, end_index, new_next_token

    def delete_rule(self, name):
        """Delete rule ``name`` together with its tags.

        :returns: ``True`` if the rule existed, ``False`` if it did not.
            Deleting an unknown rule is a no-op rather than an error.
        """
        rule = self.rules.get(name)
        if rule is None:
            return False

        self.rules_order.remove(name)
        if self.tagger.has_tags(rule.arn):
            self.tagger.delete_all_tags_for_resource(rule.arn)
        del self.rules[name]
        return True

    def describe_rule(self, name):
        """Return the rule named ``name``, or ``None`` if unknown."""
        return self.rules.get(name)

    def disable_rule(self, name):
        """Disable rule ``name``; return whether the rule exists."""
        if name in self.rules:
            self.rules[name].disable()
            return True

        return False

    def enable_rule(self, name):
        """Enable rule ``name``; return whether the rule exists."""
        if name in self.rules:
            self.rules[name].enable()
            return True

        return False

    def list_rule_names_by_target(self, target_arn, next_token=None, limit=None):
        """Return names of rules in the current page that target ``target_arn``.

        Pagination is applied over all rules first; matching happens within
        the selected page.
        """
        matching_rules = []
        return_obj = {}

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(self.rules), next_token, limit)

        for i in range(start_index, end_index):
            rule = self._get_rule_by_index(i)
            for target in rule.targets:
                if target["Arn"] == target_arn:
                    matching_rules.append(rule.name)

        return_obj["RuleNames"] = matching_rules
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def list_rules(self, prefix=None, next_token=None, limit=None):
        """Return rules in the current page whose names start with ``prefix``.

        ``prefix`` is compared literally (it is not a regular expression);
        ``None`` matches every rule.
        """
        matching_rules = []
        return_obj = {}

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(self.rules), next_token, limit)

        for i in range(start_index, end_index):
            rule = self._get_rule_by_index(i)
            if prefix is None or rule.name.startswith(prefix):
                matching_rules.append(rule)

        return_obj["Rules"] = matching_rules
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def list_targets_by_rule(self, rule, next_token=None, limit=None):
        """Return one page of targets for the rule named ``rule``."""
        # We'll let a KeyError exception be thrown for response to handle if
        # rule doesn't exist.
        rule = self.rules[rule]

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(rule.targets), next_token, limit)

        returned_targets = []
        return_obj = {}

        for i in range(start_index, end_index):
            returned_targets.append(rule.targets[i])

        return_obj["Targets"] = returned_targets
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def put_rule(self, name, **kwargs):
        """Create (or replace) rule ``name`` and append it to the rule order."""
        new_rule = Rule(name, self.region_name, **kwargs)
        self.rules[new_rule.name] = new_rule
        self.rules_order.append(new_rule.name)
        return new_rule

    def put_targets(self, name, targets):
        """Attach ``targets`` to rule ``name``; return whether the rule exists."""
        rule = self.rules.get(name)

        if rule:
            rule.put_targets(targets)
            return True

        return False

    def put_events(self, events):
        """Validate events and return one result entry per event.

        Valid events get a generated ``EventId``; invalid ones get an
        ``ErrorCode``/``ErrorMessage`` pair. Events are not stored.

        :raises JsonRESTError: if no events are supplied.
        :raises ValidationException: if more than 10 events are supplied.
        """
        num_events = len(events)

        if num_events < 1:
            raise JsonRESTError("ValidationError", "Need at least 1 event")
        if num_events > 10:
            # the exact error text is longer, the Value list consists of all the put events
            raise ValidationException(
                "1 validation error detected: "
                "Value '[PutEventsRequestEntry]' at 'entries' failed to satisfy constraint: "
                "Member must have length less than or equal to 10")

        entries = []
        for event in events:
            if "Source" not in event:
                entries.append({
                    "ErrorCode": "InvalidArgument",
                    "ErrorMessage": "Parameter Source is not valid. Reason: Source is a required argument.",
                })
            elif "DetailType" not in event:
                entries.append({
                    "ErrorCode": "InvalidArgument",
                    "ErrorMessage": "Parameter DetailType is not valid. Reason: DetailType is a required argument.",
                })
            elif "Detail" not in event:
                entries.append({
                    "ErrorCode": "InvalidArgument",
                    "ErrorMessage": "Parameter Detail is not valid. Reason: Detail is a required argument.",
                })
            else:
                try:
                    json.loads(event["Detail"])
                except ValueError:  # json.JSONDecodeError exists since Python 3.5
                    entries.append({
                        "ErrorCode": "MalformedDetail",
                        "ErrorMessage": "Detail is malformed.",
                    })
                    continue

                entries.append({"EventId": str(uuid4())})

        # We dont really need to store the events yet
        return entries

    def remove_targets(self, name, ids):
        """Remove targets ``ids`` from rule ``name``.

        :raises JsonRESTError: ResourceNotFoundException if the rule is unknown.
        """
        rule = self.rules.get(name)

        if not rule:
            raise JsonRESTError(
                "ResourceNotFoundException",
                "An entity that you specified does not exist",
            )

        rule.remove_targets(ids)
        return {"FailedEntries": [], "FailedEntryCount": 0}

    def test_event_pattern(self):
        raise NotImplementedError()

    def put_permission(self, event_bus_name, action, principal, statement_id):
        """Add a PutEvents permission statement to an event bus.

        ``event_bus_name`` defaults to ``"default"`` when empty.
        """
        if not event_bus_name:
            event_bus_name = "default"

        event_bus = self.describe_event_bus(event_bus_name)

        if action is None or action != "events:PutEvents":
            raise JsonRESTError(
                "ValidationException",
                "Provided value in parameter 'action' is not supported.",
            )

        if principal is None or self.ACCOUNT_ID.match(principal) is None:
            raise JsonRESTError("InvalidParameterValue",
                                r"Principal must match ^(\d{1,12}|\*)$")

        if statement_id is None or self.STATEMENT_ID.match(statement_id) is None:
            raise JsonRESTError(
                "InvalidParameterValue",
                r"StatementId must match ^[a-zA-Z0-9-_]{1,64}$")

        event_bus._permissions[statement_id] = {
            "Action": action,
            "Principal": principal,
        }

    def remove_permission(self, event_bus_name, statement_id):
        """Remove permission statement ``statement_id`` from an event bus.

        :raises JsonRESTError: ResourceNotFoundException if the bus has no
            permissions or the statement id is unknown.
        """
        if not event_bus_name:
            event_bus_name = "default"

        event_bus = self.describe_event_bus(event_bus_name)

        if not event_bus._permissions:
            raise JsonRESTError("ResourceNotFoundException",
                                "EventBus does not have a policy.")

        if not event_bus._permissions.pop(statement_id, None):
            raise JsonRESTError(
                "ResourceNotFoundException",
                "Statement with the provided id does not exist.",
            )

    def describe_event_bus(self, name):
        """Return event bus ``name``; empty names mean ``default``.

        :raises JsonRESTError: ResourceNotFoundException if the bus is unknown.
        """
        if not name:
            name = "default"

        event_bus = self.event_buses.get(name)

        if not event_bus:
            raise JsonRESTError("ResourceNotFoundException",
                                "Event bus {} does not exist.".format(name))

        return event_bus

    def create_event_bus(self, name, event_source_name=None):
        """Create and return a new event bus.

        :raises JsonRESTError: if the bus already exists, the name contains
            ``/`` without an event source, or the event source is unknown.
        """
        if name in self.event_buses:
            raise JsonRESTError(
                "ResourceAlreadyExistsException",
                "Event bus {} already exists.".format(name),
            )

        if not event_source_name and "/" in name:
            raise JsonRESTError("ValidationException",
                                "Event bus name must not contain '/'.")

        if event_source_name and event_source_name not in self.event_sources:
            raise JsonRESTError(
                "ResourceNotFoundException",
                "Event source {} does not exist.".format(event_source_name),
            )

        self.event_buses[name] = EventBus(self.region_name, name)

        return self.event_buses[name]

    def list_event_buses(self, name_prefix):
        """Return all event buses, optionally filtered by name prefix."""
        if name_prefix:
            return [
                event_bus for event_bus in self.event_buses.values()
                if event_bus.name.startswith(name_prefix)
            ]

        return list(self.event_buses.values())

    def delete_event_bus(self, name):
        """Delete an event bus; the ``default`` bus cannot be deleted."""
        if name == "default":
            raise JsonRESTError("ValidationException",
                                "Cannot delete event bus default.")

        self.event_buses.pop(name, None)

    def list_tags_for_resource(self, arn):
        """Return tags of the rule whose name is the last ``/`` segment of ``arn``."""
        name = arn.split("/")[-1]
        if name in self.rules:
            return self.tagger.list_tags_for_resource(self.rules[name].arn)
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def tag_resource(self, arn, tags):
        """Add ``tags`` to the rule identified by ``arn``."""
        name = arn.split("/")[-1]
        if name in self.rules:
            self.tagger.tag_resource(self.rules[name].arn, tags)
            return {}
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def untag_resource(self, arn, tag_names):
        """Remove tag keys ``tag_names`` from the rule identified by ``arn``."""
        name = arn.split("/")[-1]
        if name in self.rules:
            self.tagger.untag_resource_using_names(self.rules[name].arn, tag_names)
            return {}
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))
class Route53ResolverBackend(BaseBackend): """Implementation of Route53Resolver APIs.""" def __init__(self, region_name, account_id): super().__init__(region_name, account_id) self.resolver_endpoints = {} # Key is self-generated ID (endpoint_id) self.resolver_rules = {} # Key is self-generated ID (rule_id) self.resolver_rule_associations = {} # Key is resolver_rule_association_id) self.tagger = TaggingService() @staticmethod def default_vpc_endpoint_service(service_region, zones): """List of dicts representing default VPC endpoints for this service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "route53resolver" ) def associate_resolver_rule(self, region, resolver_rule_id, name, vpc_id): validate_args( [("resolverRuleId", resolver_rule_id), ("name", name), ("vPCId", vpc_id)] ) associations = [ x for x in self.resolver_rule_associations.values() if x.region == region ] if len(associations) > ResolverRuleAssociation.MAX_RULE_ASSOCIATIONS_PER_REGION: # This error message was not verified to be the same for AWS. raise LimitExceededException( f"Account '{get_account_id()}' has exceeded 'max-rule-association'" ) if resolver_rule_id not in self.resolver_rules: raise ResourceNotFoundException( f"Resolver rule with ID '{resolver_rule_id}' does not exist." ) vpcs = ec2_backends[region].describe_vpcs() if vpc_id not in [x.id for x in vpcs]: raise InvalidParameterException(f"The vpc ID '{vpc_id}' does not exist") # Can't duplicate resolver rule, vpc id associations. for association in self.resolver_rule_associations.values(): if ( resolver_rule_id == association.resolver_rule_id and vpc_id == association.vpc_id ): raise InvalidRequestException( f"Cannot associate rules with same domain name with same " f"VPC. 
Conflict with resolver rule '{resolver_rule_id}'" ) rule_association_id = f"rslvr-rrassoc-{get_random_hex(17)}" rule_association = ResolverRuleAssociation( region, rule_association_id, resolver_rule_id, vpc_id, name ) self.resolver_rule_associations[rule_association_id] = rule_association return rule_association @staticmethod def _verify_subnet_ips(region, ip_addresses, initial=True): """Perform additional checks on the IPAddresses. NOTE: This does not include IPv6 addresses. """ # only initial endpoint creation requires atleast two ip addresses if initial: if len(ip_addresses) < 2: raise InvalidRequestException( "Resolver endpoint needs to have at least 2 IP addresses" ) subnets = defaultdict(set) for subnet_id, ip_addr in [(x["SubnetId"], x["Ip"]) for x in ip_addresses]: try: subnet_info = ec2_backends[region].get_all_subnets( subnet_ids=[subnet_id] )[0] except InvalidSubnetIdError as exc: raise InvalidParameterException( f"The subnet ID '{subnet_id}' does not exist" ) from exc # IP in IPv4 CIDR range and not reserved? 
if ip_address(ip_addr) in subnet_info.reserved_ips or ip_address( ip_addr ) not in ip_network(subnet_info.cidr_block): raise InvalidRequestException( f"IP address '{ip_addr}' is either not in subnet " f"'{subnet_id}' CIDR range or is reserved" ) if ip_addr in subnets[subnet_id]: raise ResourceExistsException( f"The IP address '{ip_addr}' in subnet '{subnet_id}' is already in use" ) subnets[subnet_id].add(ip_addr) @staticmethod def _verify_security_group_ids(region, security_group_ids): """Perform additional checks on the security groups.""" if len(security_group_ids) > 10: raise InvalidParameterException("Maximum of 10 security groups are allowed") for group_id in security_group_ids: if not group_id.startswith("sg-"): raise InvalidParameterException( f"Malformed security group ID: Invalid id: '{group_id}' " f"(expecting 'sg-...')" ) try: ec2_backends[region].describe_security_groups(group_ids=[group_id]) except InvalidSecurityGroupNotFoundError as exc: raise ResourceNotFoundException( f"The security group '{group_id}' does not exist" ) from exc def create_resolver_endpoint( self, region, creator_request_id, name, security_group_ids, direction, ip_addresses, tags, ): # pylint: disable=too-many-arguments """ Return description for a newly created resolver endpoint. NOTE: IPv6 IPs are currently not being filtered when calculating the create_resolver_endpoint() IpAddresses. 
""" validate_args( [ ("creatorRequestId", creator_request_id), ("direction", direction), ("ipAddresses", ip_addresses), ("name", name), ("securityGroupIds", security_group_ids), ("ipAddresses.subnetId", ip_addresses), ] ) errmsg = self.tagger.validate_tags( tags or [], limit=ResolverEndpoint.MAX_TAGS_PER_RESOLVER_ENDPOINT ) if errmsg: raise TagValidationException(errmsg) endpoints = [x for x in self.resolver_endpoints.values() if x.region == region] if len(endpoints) > ResolverEndpoint.MAX_ENDPOINTS_PER_REGION: raise LimitExceededException( f"Account '{get_account_id()}' has exceeded 'max-endpoints'" ) for x in ip_addresses: if not x.get("Ip"): subnet_info = ec2_backends[region].get_all_subnets( subnet_ids=[x["SubnetId"]] )[0] x["Ip"] = subnet_info.get_available_subnet_ip(self) self._verify_subnet_ips(region, ip_addresses) self._verify_security_group_ids(region, security_group_ids) if creator_request_id in [ x.creator_request_id for x in self.resolver_endpoints.values() ]: raise ResourceExistsException( f"Resolver endpoint with creator request ID " f"'{creator_request_id}' already exists" ) endpoint_id = ( f"rslvr-{'in' if direction == 'INBOUND' else 'out'}-{get_random_hex(17)}" ) resolver_endpoint = ResolverEndpoint( region, endpoint_id, creator_request_id, security_group_ids, direction, ip_addresses, name, ) self.resolver_endpoints[endpoint_id] = resolver_endpoint self.tagger.tag_resource(resolver_endpoint.arn, tags or []) return resolver_endpoint def create_resolver_rule( self, region, creator_request_id, name, rule_type, domain_name, target_ips, resolver_endpoint_id, tags, ): # pylint: disable=too-many-arguments """Return description for a newly created resolver rule.""" validate_args( [ ("creatorRequestId", creator_request_id), ("ruleType", rule_type), ("domainName", domain_name), ("name", name), *[("targetIps.port", x) for x in target_ips], ("resolverEndpointId", resolver_endpoint_id), ] ) errmsg = self.tagger.validate_tags( tags or [], 
limit=ResolverRule.MAX_TAGS_PER_RESOLVER_RULE ) if errmsg: raise TagValidationException(errmsg) rules = [x for x in self.resolver_rules.values() if x.region == region] if len(rules) > ResolverRule.MAX_RULES_PER_REGION: # Did not verify that this is the actual error message. raise LimitExceededException( f"Account '{get_account_id()}' has exceeded 'max-rules'" ) # Per the AWS documentation and as seen with the AWS console, target # ips are only relevant when the value of Rule is FORWARD. However, # boto3 ignores this condition and so shall we. for ip_addr in [x["Ip"] for x in target_ips]: try: # boto3 fails with an InternalServiceException if IPv6 # addresses are used, which isn't helpful. if not isinstance(ip_address(ip_addr), IPv4Address): raise InvalidParameterException( f"Only IPv4 addresses may be used: '{ip_addr}'" ) except ValueError as exc: raise InvalidParameterException( f"Invalid IP address: '{ip_addr}'" ) from exc # The boto3 documentation indicates that ResolverEndpoint is # optional, as does the AWS documention. But if resolver_endpoint_id # is set to None or an empty string, it results in boto3 raising # a ParamValidationError either regarding the type or len of string. if resolver_endpoint_id: if resolver_endpoint_id not in [ x.id for x in self.resolver_endpoints.values() ]: raise ResourceNotFoundException( f"Resolver endpoint with ID '{resolver_endpoint_id}' does not exist." 
) if rule_type == "SYSTEM": raise InvalidRequestException( "Cannot specify resolver endpoint ID and target IP " "for SYSTEM type resolver rule" ) if creator_request_id in [ x.creator_request_id for x in self.resolver_rules.values() ]: raise ResourceExistsException( f"Resolver rule with creator request ID " f"'{creator_request_id}' already exists" ) rule_id = f"rslvr-rr-{get_random_hex(17)}" resolver_rule = ResolverRule( region, rule_id, creator_request_id, rule_type, domain_name, target_ips, resolver_endpoint_id, name, ) self.resolver_rules[rule_id] = resolver_rule self.tagger.tag_resource(resolver_rule.arn, tags or []) return resolver_rule def _validate_resolver_endpoint_id(self, resolver_endpoint_id): """Raise an exception if the id is invalid or unknown.""" validate_args([("resolverEndpointId", resolver_endpoint_id)]) if resolver_endpoint_id not in self.resolver_endpoints: raise ResourceNotFoundException( f"Resolver endpoint with ID '{resolver_endpoint_id}' does not exist" ) def delete_resolver_endpoint(self, resolver_endpoint_id): self._validate_resolver_endpoint_id(resolver_endpoint_id) # Can't delete an endpoint if there are rules associated with it. rules = [ x.id for x in self.resolver_rules.values() if x.resolver_endpoint_id == resolver_endpoint_id ] if rules: raise InvalidRequestException( f"Cannot delete resolver endpoint unless its related resolver " f"rules are deleted. 
The following rules still exist for " f"this resolver endpoint: {','.join(rules)}" ) self.tagger.delete_all_tags_for_resource(resolver_endpoint_id) resolver_endpoint = self.resolver_endpoints.pop(resolver_endpoint_id) resolver_endpoint.delete_eni() resolver_endpoint.status = "DELETING" resolver_endpoint.status_message = resolver_endpoint.status_message.replace( "Successfully created", "Deleting" ) return resolver_endpoint def _validate_resolver_rule_id(self, resolver_rule_id): """Raise an exception if the id is invalid or unknown.""" validate_args([("resolverRuleId", resolver_rule_id)]) if resolver_rule_id not in self.resolver_rules: raise ResourceNotFoundException( f"Resolver rule with ID '{resolver_rule_id}' does not exist" ) def delete_resolver_rule(self, resolver_rule_id): self._validate_resolver_rule_id(resolver_rule_id) # Can't delete an rule unless VPC's are disassociated. associations = [ x.id for x in self.resolver_rule_associations.values() if x.resolver_rule_id == resolver_rule_id ] if associations: raise ResourceInUseException( "Please disassociate this resolver rule from VPC first " "before deleting" ) self.tagger.delete_all_tags_for_resource(resolver_rule_id) resolver_rule = self.resolver_rules.pop(resolver_rule_id) resolver_rule.status = "DELETING" resolver_rule.status_message = resolver_rule.status_message.replace( "Successfully created", "Deleting" ) return resolver_rule def disassociate_resolver_rule(self, resolver_rule_id, vpc_id): validate_args([("resolverRuleId", resolver_rule_id), ("vPCId", vpc_id)]) # Non-existent rule or vpc ids? if resolver_rule_id not in self.resolver_rules: raise ResourceNotFoundException( f"Resolver rule with ID '{resolver_rule_id}' does not exist" ) # Find the matching association for this rule and vpc. 
rule_association_id = None for association in self.resolver_rule_associations.values(): if ( resolver_rule_id == association.resolver_rule_id and vpc_id == association.vpc_id ): rule_association_id = association.id break else: raise ResourceNotFoundException( f"Resolver Rule Association between Resolver Rule " f"'{resolver_rule_id}' and VPC '{vpc_id}' does not exist" ) rule_association = self.resolver_rule_associations.pop(rule_association_id) rule_association.status = "DELETING" rule_association.status_message = "Deleting Association" return rule_association def get_resolver_endpoint(self, resolver_endpoint_id): self._validate_resolver_endpoint_id(resolver_endpoint_id) return self.resolver_endpoints[resolver_endpoint_id] def get_resolver_rule(self, resolver_rule_id): """Return info for specified resolver rule.""" self._validate_resolver_rule_id(resolver_rule_id) return self.resolver_rules[resolver_rule_id] def get_resolver_rule_association(self, resolver_rule_association_id): validate_args([("resolverRuleAssociationId", resolver_rule_association_id)]) if resolver_rule_association_id not in self.resolver_rule_associations: raise ResourceNotFoundException( f"ResolverRuleAssociation '{resolver_rule_association_id}' does not Exist" ) return self.resolver_rule_associations[resolver_rule_association_id] @paginate(pagination_model=PAGINATION_MODEL) def list_resolver_endpoint_ip_addresses(self, resolver_endpoint_id): self._validate_resolver_endpoint_id(resolver_endpoint_id) endpoint = self.resolver_endpoints[resolver_endpoint_id] return endpoint.ip_descriptions() @staticmethod def _add_field_name_to_filter(filters): """Convert both styles of filter names to lowercase snake format. "IP_ADDRESS_COUNT" or "IpAddressCount" will become "ip_address_count". However, "HostVPCId" doesn't fit the pattern, so that's treated special. 
""" for rr_filter in filters: filter_name = rr_filter["Name"] if "_" not in filter_name: if "Vpc" in filter_name: filter_name = "WRONG" elif filter_name == "HostVPCId": filter_name = "host_vpc_id" elif filter_name == "VPCId": filter_name = "vpc_id" elif filter_name in ["Type", "TYPE"]: filter_name = "rule_type" elif not filter_name.isupper(): filter_name = CAMEL_TO_SNAKE_PATTERN.sub("_", filter_name) rr_filter["Field"] = filter_name.lower() @staticmethod def _validate_filters(filters, allowed_filter_names): """Raise exception if filter names are not as expected.""" for rr_filter in filters: if rr_filter["Field"] not in allowed_filter_names: raise InvalidParameterException( f"The filter '{rr_filter['Name']}' is invalid" ) if "Values" not in rr_filter: raise InvalidParameterException( f"No values specified for filter {rr_filter['Name']}" ) @staticmethod def _matches_all_filters(entity, filters): """Return True if this entity has fields matching all the filters.""" for rr_filter in filters: field_value = getattr(entity, rr_filter["Field"]) if isinstance(field_value, list): if not set(field_value).intersection(rr_filter["Values"]): return False elif isinstance(field_value, int): if str(field_value) not in rr_filter["Values"]: return False elif field_value not in rr_filter["Values"]: return False return True @paginate(pagination_model=PAGINATION_MODEL) def list_resolver_endpoints(self, filters): if not filters: filters = [] self._add_field_name_to_filter(filters) self._validate_filters(filters, ResolverEndpoint.FILTER_NAMES) endpoints = [] for endpoint in sorted(self.resolver_endpoints.values(), key=lambda x: x.name): if self._matches_all_filters(endpoint, filters): endpoints.append(endpoint) return endpoints @paginate(pagination_model=PAGINATION_MODEL) def list_resolver_rules(self, filters): if not filters: filters = [] self._add_field_name_to_filter(filters) self._validate_filters(filters, ResolverRule.FILTER_NAMES) rules = [] for rule in 
sorted(self.resolver_rules.values(), key=lambda x: x.name): if self._matches_all_filters(rule, filters): rules.append(rule) return rules @paginate(pagination_model=PAGINATION_MODEL) def list_resolver_rule_associations(self, filters): if not filters: filters = [] self._add_field_name_to_filter(filters) self._validate_filters(filters, ResolverRuleAssociation.FILTER_NAMES) rules = [] for rule in sorted( self.resolver_rule_associations.values(), key=lambda x: x.name ): if self._matches_all_filters(rule, filters): rules.append(rule) return rules def _matched_arn(self, resource_arn): """Given ARN, raise exception if there is no corresponding resource.""" for resolver_endpoint in self.resolver_endpoints.values(): if resolver_endpoint.arn == resource_arn: return for resolver_rule in self.resolver_rules.values(): if resolver_rule.arn == resource_arn: return raise ResourceNotFoundException( f"Resolver endpoint with ID '{resource_arn}' does not exist" ) @paginate(pagination_model=PAGINATION_MODEL) def list_tags_for_resource(self, resource_arn): self._matched_arn(resource_arn) return self.tagger.list_tags_for_resource(resource_arn).get("Tags") def tag_resource(self, resource_arn, tags): self._matched_arn(resource_arn) errmsg = self.tagger.validate_tags( tags, limit=ResolverEndpoint.MAX_TAGS_PER_RESOLVER_ENDPOINT ) if errmsg: raise TagValidationException(errmsg) self.tagger.tag_resource(resource_arn, tags) def untag_resource(self, resource_arn, tag_keys): self._matched_arn(resource_arn) self.tagger.untag_resource_using_names(resource_arn, tag_keys) def update_resolver_endpoint(self, resolver_endpoint_id, name): self._validate_resolver_endpoint_id(resolver_endpoint_id) validate_args([("name", name)]) resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id] resolver_endpoint.update_name(name) return resolver_endpoint def associate_resolver_endpoint_ip_address( self, region, resolver_endpoint_id, ip_address ): self._validate_resolver_endpoint_id(resolver_endpoint_id) 
resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id] if not ip_address.get("Ip"): subnet_info = ec2_backends[region].get_all_subnets( subnet_ids=[ip_address.get("SubnetId")] )[0] ip_address["Ip"] = subnet_info.get_available_subnet_ip(self) self._verify_subnet_ips(region, [ip_address], False) resolver_endpoint.associate_ip_address(ip_address) return resolver_endpoint def disassociate_resolver_endpoint_ip_address( self, resolver_endpoint_id, ip_address ): self._validate_resolver_endpoint_id(resolver_endpoint_id) resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id] if not (ip_address.get("Ip") or ip_address.get("IpId")): raise InvalidRequestException( "[RSLVR-00503] Need to specify either the IP ID or both subnet and IP address in order to remove IP address." ) resolver_endpoint.disassociate_ip_address(ip_address) return resolver_endpoint
class FirehoseBackend(BaseBackend):
    """Implementation of Firehose APIs."""

    def __init__(self, region_name=None):
        self.region_name = region_name
        self.delivery_streams = {}
        self.tagger = TaggingService()

    def reset(self):
        """Re-initializes all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "firehose", special_service_name="kinesis-firehose"
        )

    def _get_stream_or_raise(self, delivery_stream_name):
        """Return the named delivery stream or raise ResourceNotFoundException."""
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found."
            )
        return delivery_stream

    def create_delivery_stream(
        self,
        region,
        delivery_stream_name,
        delivery_stream_type,
        kinesis_stream_source_configuration,
        delivery_stream_encryption_configuration_input,
        s3_destination_configuration,
        extended_s3_destination_configuration,
        redshift_destination_configuration,
        elasticsearch_destination_configuration,
        splunk_destination_configuration,
        http_endpoint_destination_configuration,
        tags,
    ):  # pylint: disable=too-many-arguments,too-many-locals,unused-argument
        """Create a Kinesis Data Firehose delivery stream.

        The destination configuration is picked out of the *_configuration
        arguments via locals(), so this must remain the first statement.

        Returns:
            The ARN of the new delivery stream.

        Raises:
            ResourceInUseException: if the name is already taken.
            LimitExceededException: if the per-region stream quota is reached.
            InvalidArgumentException: if a Kinesis source config is given for
                a stream type other than KinesisStreamAsSource.
            ValidationException: if the tags are invalid or too many.
        """
        (destination_name, destination_config) = find_destination_config_in_args(
            locals()
        )

        if delivery_stream_name in self.delivery_streams:
            raise ResourceInUseException(
                f"Firehose {delivery_stream_name} under accountId {get_account_id()} "
                f"already exists"
            )

        if len(self.delivery_streams) >= DeliveryStream.MAX_STREAMS_PER_REGION:
            raise LimitExceededException(
                f"You have already consumed your firehose quota of "
                f"{DeliveryStream.MAX_STREAMS_PER_REGION} hoses. Firehose "
                f"names: {list(self.delivery_streams.keys())}"
            )

        # Rule out situations that are not yet implemented.
        if delivery_stream_encryption_configuration_input:
            warnings.warn(
                "A delivery stream with server-side encryption enabled is not "
                "yet implemented"
            )

        if destination_name == "Splunk":
            warnings.warn("A Splunk destination delivery stream is not yet implemented")

        if (
            kinesis_stream_source_configuration
            and delivery_stream_type != "KinesisStreamAsSource"
        ):
            raise InvalidArgumentException(
                "KinesisSourceStreamConfig is only applicable for "
                "KinesisStreamAsSource stream type"
            )

        # Validate the tags before proceeding.
        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)

        if tags and len(tags) > MAX_TAGS_PER_DELIVERY_STREAM:
            raise ValidationException(
                f"1 validation error detected: Value '{tags}' at 'tags' "
                f"failed to satisify contstraint: Member must have length "
                f"less than or equal to {MAX_TAGS_PER_DELIVERY_STREAM}"
            )

        # Create a DeliveryStream instance that will be stored and indexed
        # by delivery stream name.  This instance will update the state and
        # create the ARN.
        delivery_stream = DeliveryStream(
            region,
            delivery_stream_name,
            delivery_stream_type,
            kinesis_stream_source_configuration,
            destination_name,
            destination_config,
        )
        self.tagger.tag_resource(delivery_stream.delivery_stream_arn, tags or [])

        self.delivery_streams[delivery_stream_name] = delivery_stream
        return self.delivery_streams[delivery_stream_name].delivery_stream_arn

    def delete_delivery_stream(
        self, delivery_stream_name, allow_force_delete=False
    ):  # pylint: disable=unused-argument
        """Delete a delivery stream and its data.

        AllowForceDelete option is ignored as we only superficially
        apply state.
        """
        delivery_stream = self._get_stream_or_raise(delivery_stream_name)

        self.tagger.delete_all_tags_for_resource(delivery_stream.delivery_stream_arn)

        delivery_stream.delivery_stream_status = "DELETING"
        self.delivery_streams.pop(delivery_stream_name)

    def describe_delivery_stream(
        self, delivery_stream_name, limit, exclusive_start_destination_id
    ):  # pylint: disable=unused-argument
        """Return description of specified delivery stream and its status.

        Note:  the 'limit' and 'exclusive_start_destination_id' parameters
        are not currently processed/implemented.
        """
        delivery_stream = self._get_stream_or_raise(delivery_stream_name)

        result = {"DeliveryStreamDescription": {"HasMoreDestinations": False}}
        for attribute, attribute_value in vars(delivery_stream).items():
            # Falsy attributes (None, empty containers) are omitted.
            if not attribute_value:
                continue

            # Convert from attribute's snake case to camel case for outgoing
            # JSON.
            name = "".join([x.capitalize() for x in attribute.split("_")])

            # Fooey ... always an exception to the rule:
            if name == "DeliveryStreamArn":
                name = "DeliveryStreamARN"

            if name != "Destinations":
                if name == "Source":
                    result["DeliveryStreamDescription"][name] = {
                        "KinesisStreamSourceDescription": attribute_value
                    }
                else:
                    result["DeliveryStreamDescription"][name] = attribute_value
                continue

            result["DeliveryStreamDescription"]["Destinations"] = []
            for destination in attribute_value:
                description = {}
                for key, value in destination.items():
                    if key == "destination_id":
                        description["DestinationId"] = value
                    else:
                        description[f"{key}DestinationDescription"] = value

                result["DeliveryStreamDescription"]["Destinations"].append(description)

        return result

    def list_delivery_streams(
        self, limit, delivery_stream_type, exclusive_start_delivery_stream_name
    ):
        """Return list of delivery streams in alphabetic order of names."""
        result = {"DeliveryStreamNames": [], "HasMoreDeliveryStreams": False}
        if not self.delivery_streams:
            return result

        # If delivery_stream_type is specified, filter out any stream that's
        # not of that type.
        stream_list = self.delivery_streams.keys()
        if delivery_stream_type:
            stream_list = [
                x
                for x in stream_list
                if self.delivery_streams[x].delivery_stream_type == delivery_stream_type
            ]

        # The list is sorted alphabetically, not alphanumerically.
        sorted_list = sorted(stream_list)

        # Determine the limit or number of names to return in the list.
        limit = limit or DeliveryStream.MAX_STREAMS_PER_REGION

        # If a starting delivery stream name is given, find the index into
        # the sorted list, then add one to get the name following it.  If the
        # exclusive_start_delivery_stream_name isn't in the (possibly
        # type-filtered) list, it's ignored.  Checking the filtered list
        # rather than the full dict avoids a ValueError from index() when
        # the named stream exists but was filtered out.
        start = 0
        if (
            exclusive_start_delivery_stream_name
            and exclusive_start_delivery_stream_name in sorted_list
        ):
            start = sorted_list.index(exclusive_start_delivery_stream_name) + 1

        result["DeliveryStreamNames"] = sorted_list[start : start + limit]
        if len(sorted_list) > (start + limit):
            result["HasMoreDeliveryStreams"] = True
        return result

    def list_tags_for_delivery_stream(
        self, delivery_stream_name, exclusive_start_tag_key, limit
    ):
        """Return list of tags."""
        result = {"Tags": [], "HasMoreTags": False}
        delivery_stream = self._get_stream_or_raise(delivery_stream_name)

        tags = self.tagger.list_tags_for_resource(delivery_stream.delivery_stream_arn)[
            "Tags"
        ]
        keys = self.tagger.extract_tag_names(tags)

        # If a starting tag is given and can be found, find the index into
        # tags, then add one to get the tag following it.
        start = 0
        if exclusive_start_tag_key:
            if exclusive_start_tag_key in keys:
                start = keys.index(exclusive_start_tag_key) + 1

        limit = limit or MAX_TAGS_PER_DELIVERY_STREAM
        result["Tags"] = tags[start : start + limit]
        if len(tags) > (start + limit):
            result["HasMoreTags"] = True
        return result

    def put_record(self, delivery_stream_name, record):
        """Write a single data record into a Kinesis Data firehose stream."""
        result = self.put_record_batch(delivery_stream_name, [record])
        return {
            "RecordId": result["RequestResponses"][0]["RecordId"],
            "Encrypted": False,
        }

    @staticmethod
    def put_http_records(http_destination, records):
        """Put records to a HTTP destination."""
        # Mostly copied from localstack
        url = http_destination["EndpointConfiguration"]["Url"]
        headers = {"Content-Type": "application/json"}
        record_to_send = {
            "requestId": str(uuid4()),
            "timestamp": int(time()),
            "records": [{"data": record["Data"]} for record in records],
        }
        try:
            requests.post(url, json=record_to_send, headers=headers)
        except Exception as exc:
            # This could be better ...
            raise RuntimeError(
                "Firehose PutRecord(Batch) to HTTP destination failed"
            ) from exc
        return [{"RecordId": str(uuid4())} for _ in range(len(records))]

    @staticmethod
    def _format_s3_object_path(delivery_stream_name, version_id, prefix):
        """Return a S3 object path in the expected format."""
        # Taken from LocalStack's firehose logic, with minor changes.
        # See https://docs.aws.amazon.com/firehose/latest/dev/basic-deliver.html#s3-object-name
        # Path prefix pattern: myApp/YYYY/MM/DD/HH/
        # Object name pattern:
        # DeliveryStreamName-DeliveryStreamVersion-YYYY-MM-DD-HH-MM-SS-RandomString
        prefix = f"{prefix}{'' if prefix.endswith('/') else '/'}"
        now = datetime.utcnow()
        return (
            f"{prefix}{now.strftime('%Y/%m/%d/%H')}/"
            f"{delivery_stream_name}-{version_id}-"
            f"{now.strftime('%Y-%m-%d-%H-%M-%S')}-{str(uuid4())}"
        )

    def put_s3_records(self, delivery_stream_name, version_id, s3_destination, records):
        """Put records to a ExtendedS3 or S3 destination."""
        # Taken from LocalStack's firehose logic, with minor changes.
        bucket_name = s3_destination["BucketARN"].split(":")[-1]
        prefix = s3_destination.get("Prefix", "")
        object_path = self._format_s3_object_path(
            delivery_stream_name, version_id, prefix
        )

        batched_data = b"".join([b64decode(r["Data"]) for r in records])
        try:
            s3_backend.put_object(bucket_name, object_path, batched_data)
        except Exception as exc:
            # This could be better ...
            raise RuntimeError(
                "Firehose PutRecord(Batch) to S3 destination failed"
            ) from exc
        return [{"RecordId": str(uuid4())} for _ in range(len(records))]

    def put_record_batch(self, delivery_stream_name, records):
        """Write multiple data records into a Kinesis Data firehose stream."""
        delivery_stream = self._get_stream_or_raise(delivery_stream_name)

        request_responses = []
        for destination in delivery_stream.destinations:
            if "ExtendedS3" in destination:
                # ExtendedS3 will be handled like S3,but in the future
                # this will probably need to be revisited.  This destination
                # must be listed before S3 otherwise both destinations will
                # be processed instead of just ExtendedS3.
                request_responses = self.put_s3_records(
                    delivery_stream_name,
                    delivery_stream.version_id,
                    destination["ExtendedS3"],
                    records,
                )
            elif "S3" in destination:
                request_responses = self.put_s3_records(
                    delivery_stream_name,
                    delivery_stream.version_id,
                    destination["S3"],
                    records,
                )
            elif "HttpEndpoint" in destination:
                request_responses = self.put_http_records(
                    destination["HttpEndpoint"], records
                )
            elif "Elasticsearch" in destination or "Redshift" in destination:
                # This isn't implmented as these services aren't implemented,
                # so ignore the data, but return a "proper" response.
                request_responses = [
                    {"RecordId": str(uuid4())} for _ in range(len(records))
                ]

        return {
            "FailedPutCount": 0,
            "Encrypted": False,
            "RequestResponses": request_responses,
        }

    def tag_delivery_stream(self, delivery_stream_name, tags):
        """Add/update tags for specified delivery stream."""
        delivery_stream = self._get_stream_or_raise(delivery_stream_name)

        if len(tags) > MAX_TAGS_PER_DELIVERY_STREAM:
            raise ValidationException(
                f"1 validation error detected: Value '{tags}' at 'tags' "
                f"failed to satisify contstraint: Member must have length "
                f"less than or equal to {MAX_TAGS_PER_DELIVERY_STREAM}"
            )

        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)

        self.tagger.tag_resource(delivery_stream.delivery_stream_arn, tags)

    def untag_delivery_stream(self, delivery_stream_name, tag_keys):
        """Removes tags from specified delivery stream."""
        delivery_stream = self._get_stream_or_raise(delivery_stream_name)

        # If a tag key doesn't exist for the stream, boto3 ignores it.
        self.tagger.untag_resource_using_names(
            delivery_stream.delivery_stream_arn, tag_keys
        )

    def update_destination(
        self,
        delivery_stream_name,
        current_delivery_stream_version_id,
        destination_id,
        s3_destination_update,
        extended_s3_destination_update,
        s3_backup_mode,
        redshift_destination_update,
        elasticsearch_destination_update,
        splunk_destination_update,
        http_endpoint_destination_update,
    ):  # pylint: disable=unused-argument,too-many-arguments,too-many-locals
        """Updates specified destination of specified delivery stream.

        The destination update is picked out of the *_update arguments via
        locals(), so that call must remain the first statement.

        Raises:
            ResourceNotFoundException: if the stream is unknown.
            ConcurrentModificationException: if the version ids differ.
            InvalidArgumentException: if the destination id is unknown or the
                destination type change is unsupported.
        """
        (destination_name, destination_config) = find_destination_config_in_args(
            locals()
        )

        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under accountId "
                f"{get_account_id()} not found."
            )

        if destination_name == "Splunk":
            warnings.warn("A Splunk destination delivery stream is not yet implemented")

        if delivery_stream.version_id != current_delivery_stream_version_id:
            raise ConcurrentModificationException(
                f"Cannot update firehose: {delivery_stream_name} since the "
                f"current version id: {delivery_stream.version_id} and "
                f"specified version id: {current_delivery_stream_version_id} "
                f"do not match"
            )

        for destination_idx, destination in enumerate(delivery_stream.destinations):
            if destination["destination_id"] == destination_id:
                break
        else:
            # Previously missing the f-prefix, so the id was never shown.
            raise InvalidArgumentException(f"Destination Id {destination_id} not found")

        # Switching between Amazon ES and other services is not supported.
        # For an Amazon ES destination, you can only update to another Amazon
        # ES destination.  Same with HTTP.  Didn't test Splunk.
        if (
            destination_name == "Elasticsearch" and "Elasticsearch" not in destination
        ) or (destination_name == "HttpEndpoint" and "HttpEndpoint" not in destination):
            raise InvalidArgumentException(
                f"Changing the destination type to or from {destination_name} "
                f"is not supported at this time."
            )

        # If this is a different type of destination configuration,
        # the existing configuration is reset first.
        if destination_name in destination:
            delivery_stream.destinations[destination_idx][destination_name].update(
                destination_config
            )
        else:
            delivery_stream.destinations[destination_idx] = {
                "destination_id": destination_id,
                destination_name: destination_config,
            }

        # Once S3 is updated to an ExtendedS3 destination, both remain in
        # the destination.  That means when one is updated, the other needs
        # to be updated as well.  The problem is that they don't have the
        # same fields.
        if destination_name == "ExtendedS3":
            delivery_stream.destinations[destination_idx][
                "S3"
            ] = create_s3_destination_config(destination_config)
        elif destination_name == "S3" and "ExtendedS3" in destination:
            destination["ExtendedS3"] = {
                k: v
                for k, v in destination["S3"].items()
                if k in destination["ExtendedS3"]
            }

        # Increment version number and update the timestamp.
        delivery_stream.version_id = str(int(current_delivery_stream_version_id) + 1)
        delivery_stream.last_update_timestamp = datetime.now(timezone.utc).isoformat()

        # Unimplemented: processing of the "S3BackupMode" parameter.  Per the
        # documentation:  "You can update a delivery stream to enable Amazon
        # S3 backup if it is disabled.  If backup is enabled, you can't update
        # the delivery stream to disable it."

    def lookup_name_from_arn(self, arn):
        """Given an ARN, return the associated delivery stream object.

        The stream name is taken as the text after the last "/" in the ARN.
        Returns None if no stream has that name.
        """
        return self.delivery_streams.get(arn.split("/")[-1])

    def send_log_event(
        self,
        delivery_stream_arn,
        filter_name,
        log_group_name,
        log_stream_name,
        log_events,
    ):  # pylint: disable=too-many-arguments
        """Send log events to a S3 bucket after encoding and gzipping it."""
        data = {
            "logEvents": log_events,
            "logGroup": log_group_name,
            "logStream": log_stream_name,
            "messageType": "DATA_MESSAGE",
            "owner": get_account_id(),
            "subscriptionFilters": [filter_name],
        }

        output = io.BytesIO()
        with GzipFile(fileobj=output, mode="w") as fhandle:
            fhandle.write(json.dumps(data, separators=(",", ":")).encode("utf-8"))
        gzipped_payload = b64encode(output.getvalue()).decode("utf-8")

        delivery_stream = self.lookup_name_from_arn(delivery_stream_arn)
        self.put_s3_records(
            delivery_stream.delivery_stream_name,
            delivery_stream.version_id,
            delivery_stream.destinations[0]["S3"],
            [{"Data": gzipped_payload}],
        )
class DirectoryServiceBackend(BaseBackend):
    """Implementation of DirectoryService APIs."""

    def __init__(self, region_name=None):
        self.region_name = region_name
        self.directories = {}
        self.tagger = TaggingService()

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """List of dicts representing default VPC endpoints for this service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "ds"
        )

    @staticmethod
    def _validate_create_directory_args(
        name, passwd, size, vpc_settings, description, short_name,
    ):  # pylint: disable=too-many-arguments
        """Raise exception if create_directory() args don't meet constraints.

        The error messages are accumulated before the exception is raised,
        so a single DsValidationException reports every failing argument.
        """
        error_tuples = []
        passwd_pattern = (
            r"(?=^.{8,64}$)((?=.*\d)(?=.*[A-Z])(?=.*[a-z])|"
            r"(?=.*\d)(?=.*[^A-Za-z0-9\s])(?=.*[a-z])|"
            r"(?=.*[^A-Za-z0-9\s])(?=.*[A-Z])(?=.*[a-z])|"
            r"(?=.*\d)(?=.*[A-Z])(?=.*[^A-Za-z0-9\s]))^.*"
        )
        if not re.match(passwd_pattern, passwd):
            # Can't have an odd number of backslashes in a literal.
            json_pattern = passwd_pattern.replace("\\", r"\\")
            error_tuples.append(
                (
                    "password",
                    passwd,
                    fr"satisfy regular expression pattern: {json_pattern}",
                )
            )

        if size.lower() not in ["small", "large"]:
            error_tuples.append(
                ("size", size, "satisfy enum value set: [Small, Large]")
            )

        name_pattern = r"^([a-zA-Z0-9]+[\\.-])+([a-zA-Z0-9])+$"
        if not re.match(name_pattern, name):
            error_tuples.append(
                ("name", name, fr"satisfy regular expression pattern: {name_pattern}")
            )

        subnet_id_pattern = r"^(subnet-[0-9a-f]{8}|subnet-[0-9a-f]{17})$"
        for subnet in vpc_settings["SubnetIds"]:
            if not re.match(subnet_id_pattern, subnet):
                error_tuples.append(
                    (
                        "vpcSettings.subnetIds",
                        subnet,
                        fr"satisfy regular expression pattern: {subnet_id_pattern}",
                    )
                )

        if description and len(description) > 128:
            error_tuples.append(
                ("description", description, "have length less than or equal to 128")
            )

        short_name_pattern = r'^[^\/:*?"<>|.]+[^\/:*?"<>|]*$'
        if short_name and not re.match(short_name_pattern, short_name):
            json_pattern = short_name_pattern.replace("\\", r"\\").replace('"', r"\"")
            error_tuples.append(
                (
                    "shortName",
                    short_name,
                    fr"satisfy regular expression pattern: {json_pattern}",
                )
            )

        if error_tuples:
            raise DsValidationException(error_tuples)

    @staticmethod
    def _validate_vpc_setting_values(region, vpc_settings):
        """Raise exception if vpc_settings are invalid.

        If settings are valid, add AvailabilityZones to vpc_settings.
        """
        if len(vpc_settings["SubnetIds"]) != 2:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones."
            )

        from moto.ec2 import ec2_backends  # pylint: disable=import-outside-toplevel

        # Subnet IDs are checked before the VPC ID.  The Subnet IDs must
        # be valid and in different availability zones.
        try:
            subnets = ec2_backends[region].get_all_subnets(
                subnet_ids=vpc_settings["SubnetIds"]
            )
        except InvalidSubnetIdError as exc:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones."
            ) from exc

        regions = [subnet.availability_zone for subnet in subnets]
        if regions[0] == regions[1]:
            raise ClientException(
                "Invalid subnet ID(s). The two subnets must be in "
                "different Availability Zones."
            )

        vpcs = ec2_backends[region].describe_vpcs()
        if vpc_settings["VpcId"] not in [x.id for x in vpcs]:
            raise ClientException("Invalid VPC ID.")
        vpc_settings["AvailabilityZones"] = regions

    def create_directory(
        self, region, name, short_name, password, description, size, vpc_settings, tags
    ):  # pylint: disable=too-many-arguments
        """Create a fake Simple Ad Directory.

        Returns:
            The new directory's ID.

        Raises:
            DirectoryLimitExceededException: if too many directories or tags.
            InvalidParameterException: if vpc_settings is missing or invalid.
            DsValidationException: if any argument fails its constraints.
            ValidationException: if the tags are invalid.
        """
        if len(self.directories) > Directory.CLOUDONLY_DIRECTORIES_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CLOUDONLY_DIRECTORIES_LIMIT} directories may be created"
            )

        # botocore doesn't look for missing vpc_settings, but boto3 does.
        if not vpc_settings:
            raise InvalidParameterException("VpcSettings must be specified.")

        self._validate_create_directory_args(
            name, password, size, vpc_settings, description, short_name,
        )
        self._validate_vpc_setting_values(region, vpc_settings)

        # Tags are optional; normalize None to an empty list so the length
        # check below doesn't fail with a TypeError.
        tags = tags or []
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)

        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            name,
            password,
            size,
            vpc_settings,
            directory_type="SimpleAD",
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags)
        return directory.directory_id

    def _validate_directory_id(self, directory_id):
        """Raise an exception if the directory id is invalid or unknown."""
        # Validation of ID takes precedence over a check for its existence.
        id_pattern = r"^d-[0-9a-f]{10}$"
        if not re.match(id_pattern, directory_id):
            raise DsValidationException(
                [
                    (
                        "directoryId",
                        directory_id,
                        fr"satisfy regular expression pattern: {id_pattern}",
                    )
                ]
            )

        if directory_id not in self.directories:
            raise EntityDoesNotExistException(
                f"Directory {directory_id} does not exist"
            )

    def delete_directory(self, directory_id):
        """Delete directory with the matching ID."""
        self._validate_directory_id(directory_id)
        self.tagger.delete_all_tags_for_resource(directory_id)
        self.directories.pop(directory_id)
        return directory_id

    @paginate(pagination_model=PAGINATION_MODEL)
    def describe_directories(
        self, directory_ids=None, next_token=None, limit=0
    ):  # pylint: disable=unused-argument
        """Return info on all directories or directories with matching IDs."""
        for directory_id in directory_ids or self.directories:
            self._validate_directory_id(directory_id)

        directories = list(self.directories.values())
        if directory_ids:
            directories = [x for x in directories if x.directory_id in directory_ids]
        return sorted(directories, key=lambda x: x.launch_time)

    def get_directory_limits(self):
        """Return hard-coded limits for the directories."""
        counts = {"SimpleAD": 0, "MicrosoftAD": 0, "ConnectedAD": 0}
        for directory in self.directories.values():
            if directory.directory_type == "SimpleAD":
                counts["SimpleAD"] += 1
            elif directory.directory_type in ["MicrosoftAD", "SharedMicrosoftAD"]:
                counts["MicrosoftAD"] += 1
            elif directory.directory_type == "ADConnector":
                counts["ConnectedAD"] += 1

        return {
            "CloudOnlyDirectoriesLimit": Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyDirectoriesCurrentCount": counts["SimpleAD"],
            "CloudOnlyDirectoriesLimitReached": counts["SimpleAD"]
            == Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyMicrosoftADLimit": Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "CloudOnlyMicrosoftADCurrentCount": counts["MicrosoftAD"],
            "CloudOnlyMicrosoftADLimitReached": counts["MicrosoftAD"]
            == Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "ConnectedDirectoriesLimit": Directory.CONNECTED_DIRECTORIES_LIMIT,
            "ConnectedDirectoriesCurrentCount": counts["ConnectedAD"],
            "ConnectedDirectoriesLimitReached": counts["ConnectedAD"]
            == Directory.CONNECTED_DIRECTORIES_LIMIT,
        }

    def add_tags_to_resource(self, resource_id, tags):
        """Add or overwrite one or more tags for specified directory."""
        self._validate_directory_id(resource_id)
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise TagLimitExceededException("Tag limit exceeded")
        self.tagger.tag_resource(resource_id, tags)

    def remove_tags_from_resource(self, resource_id, tag_keys):
        """Removes tags from a directory."""
        self._validate_directory_id(resource_id)
        self.tagger.untag_resource_using_names(resource_id, tag_keys)

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags_for_resource(
        self, resource_id, next_token=None, limit=None,
    ):  # pylint: disable=unused-argument
        """List all tags on a directory."""
        self._validate_directory_id(resource_id)
        return self.tagger.list_tags_for_resource(resource_id).get("Tags")
class KmsBackend(BaseBackend): def __init__(self, region_name, account_id=None): super().__init__(region_name=region_name, account_id=account_id) self.keys = {} self.key_to_aliases = defaultdict(set) self.tagger = TaggingService(key_name="TagKey", value_name="TagValue") @staticmethod def default_vpc_endpoint_service(service_region, zones): """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "kms") def _generate_default_keys(self, alias_name): """Creates default kms keys""" if alias_name in RESERVED_ALIASES: key = self.create_key( None, "ENCRYPT_DECRYPT", "SYMMETRIC_DEFAULT", "Default key", None, self.region_name, ) self.add_alias(key.id, alias_name) return key.id def create_key(self, policy, key_usage, key_spec, description, tags, region, multi_region=False): key = Key(policy, key_usage, key_spec, description, region, multi_region) self.keys[key.id] = key if tags is not None and len(tags) > 0: self.tag_resource(key.id, tags) return key # https://docs.aws.amazon.com/kms/latest/developerguide/multi-region-keys-overview.html#mrk-sync-properties # In AWS replicas of a key only share some properties with the original key. Some of those properties get updated # in all replicas automatically if those properties change in the original key. Also, such properties can not be # changed for replicas directly. # # In our implementation with just create a copy of all the properties once without any protection from change, # as the exact implementation is currently infeasible. def replicate_key(self, key_id, replica_region): # Using copy() instead of deepcopy(), as the latter results in exception: # TypeError: cannot pickle '_cffi_backend.FFI' object # Since we only update top level properties, copy() should suffice. 
replica_key = copy(self.keys[key_id]) replica_key.region = replica_region to_region_backend = kms_backends[replica_region] to_region_backend.keys[replica_key.id] = replica_key def update_key_description(self, key_id, description): key = self.keys[self.get_key_id(key_id)] key.description = description def delete_key(self, key_id): if key_id in self.keys: if key_id in self.key_to_aliases: self.key_to_aliases.pop(key_id) self.tagger.delete_all_tags_for_resource(key_id) return self.keys.pop(key_id) def describe_key(self, key_id) -> Key: # allow the different methods (alias, ARN :key/, keyId, ARN alias) to # describe key not just KeyId key_id = self.get_key_id(key_id) if r"alias/" in str(key_id).lower(): key_id = self.get_key_id_from_alias(key_id) return self.keys[self.get_key_id(key_id)] def list_keys(self): return self.keys.values() @staticmethod def get_key_id(key_id): # Allow use of ARN as well as pure KeyId if key_id.startswith("arn:") and ":key/" in key_id: return key_id.split(":key/")[1] return key_id @staticmethod def get_alias_name(alias_name): # Allow use of ARN as well as alias name if alias_name.startswith("arn:") and ":alias/" in alias_name: return "alias/" + alias_name.split(":alias/")[1] return alias_name def any_id_to_key_id(self, key_id): """Go from any valid key ID to the raw key ID. 
Acceptable inputs: - raw key ID - key ARN - alias name - alias ARN """ key_id = self.get_alias_name(key_id) key_id = self.get_key_id(key_id) if key_id.startswith("alias/"): key_id = self.get_key_id_from_alias(key_id) return key_id def alias_exists(self, alias_name): for aliases in self.key_to_aliases.values(): if alias_name in aliases: return True return False def add_alias(self, target_key_id, alias_name): self.key_to_aliases[target_key_id].add(alias_name) def delete_alias(self, alias_name): """Delete the alias.""" for aliases in self.key_to_aliases.values(): if alias_name in aliases: aliases.remove(alias_name) def get_all_aliases(self): return self.key_to_aliases def get_key_id_from_alias(self, alias_name): for key_id, aliases in dict(self.key_to_aliases).items(): if alias_name in ",".join(aliases): return key_id if alias_name in RESERVED_ALIASES: key_id = self._generate_default_keys(alias_name) return key_id return None def enable_key_rotation(self, key_id): self.keys[self.get_key_id(key_id)].key_rotation_status = True def disable_key_rotation(self, key_id): self.keys[self.get_key_id(key_id)].key_rotation_status = False def get_key_rotation_status(self, key_id): return self.keys[self.get_key_id(key_id)].key_rotation_status def put_key_policy(self, key_id, policy): self.keys[self.get_key_id(key_id)].policy = policy def get_key_policy(self, key_id): return self.keys[self.get_key_id(key_id)].policy def disable_key(self, key_id): self.keys[key_id].enabled = False self.keys[key_id].key_state = "Disabled" def enable_key(self, key_id): self.keys[key_id].enabled = True self.keys[key_id].key_state = "Enabled" def cancel_key_deletion(self, key_id): self.keys[key_id].key_state = "Disabled" self.keys[key_id].deletion_date = None def schedule_key_deletion(self, key_id, pending_window_in_days): if 7 <= pending_window_in_days <= 30: self.keys[key_id].enabled = False self.keys[key_id].key_state = "PendingDeletion" self.keys[key_id].deletion_date = datetime.now() + timedelta( 
days=pending_window_in_days) return unix_time(self.keys[key_id].deletion_date) def encrypt(self, key_id, plaintext, encryption_context): key_id = self.any_id_to_key_id(key_id) ciphertext_blob = encrypt( master_keys=self.keys, key_id=key_id, plaintext=plaintext, encryption_context=encryption_context, ) arn = self.keys[key_id].arn return ciphertext_blob, arn def decrypt(self, ciphertext_blob, encryption_context): plaintext, key_id = decrypt( master_keys=self.keys, ciphertext_blob=ciphertext_blob, encryption_context=encryption_context, ) arn = self.keys[key_id].arn return plaintext, arn def re_encrypt( self, ciphertext_blob, source_encryption_context, destination_key_id, destination_encryption_context, ): destination_key_id = self.any_id_to_key_id(destination_key_id) plaintext, decrypting_arn = self.decrypt( ciphertext_blob=ciphertext_blob, encryption_context=source_encryption_context, ) new_ciphertext_blob, encrypting_arn = self.encrypt( key_id=destination_key_id, plaintext=plaintext, encryption_context=destination_encryption_context, ) return new_ciphertext_blob, decrypting_arn, encrypting_arn def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec): key_id = self.any_id_to_key_id(key_id) if key_spec: # Note: Actual validation of key_spec is done in kms.responses if key_spec == "AES_128": plaintext_len = 16 else: plaintext_len = 32 else: plaintext_len = number_of_bytes plaintext = os.urandom(plaintext_len) ciphertext_blob, arn = self.encrypt( key_id=key_id, plaintext=plaintext, encryption_context=encryption_context) return plaintext, ciphertext_blob, arn def list_resource_tags(self, key_id_or_arn): key_id = self.get_key_id(key_id_or_arn) if key_id in self.keys: return self.tagger.list_tags_for_resource(key_id) raise JsonRESTError( "NotFoundException", "The request was rejected because the specified entity or resource could not be found.", ) def tag_resource(self, key_id_or_arn, tags): key_id = self.get_key_id(key_id_or_arn) if key_id in 
self.keys: self.tagger.tag_resource(key_id, tags) return {} raise JsonRESTError( "NotFoundException", "The request was rejected because the specified entity or resource could not be found.", ) def untag_resource(self, key_id_or_arn, tag_names): key_id = self.get_key_id(key_id_or_arn) if key_id in self.keys: self.tagger.untag_resource_using_names(key_id, tag_names) return {} raise JsonRESTError( "NotFoundException", "The request was rejected because the specified entity or resource could not be found.", ) def create_grant( self, key_id, grantee_principal, operations, name, constraints, retiring_principal, ): key = self.describe_key(key_id) grant = key.add_grant( name, grantee_principal, operations, constraints=constraints, retiring_principal=retiring_principal, ) return grant.id, grant.token def list_grants(self, key_id, grant_id) -> [Grant]: key = self.describe_key(key_id) return key.list_grants(grant_id) def list_retirable_grants(self, retiring_principal): grants = [] for key in self.keys.values(): grants.extend(key.list_retirable_grants(retiring_principal)) return grants def revoke_grant(self, key_id, grant_id) -> None: key = self.describe_key(key_id) key.revoke_grant(grant_id) def retire_grant(self, key_id, grant_id, grant_token) -> None: if grant_token: for key in self.keys.values(): key.retire_grant_by_token(grant_token) else: key = self.describe_key(key_id) key.retire_grant(grant_id) def __ensure_valid_sign_and_verify_key(self, key: Key): if key.key_usage != "SIGN_VERIFY": raise ValidationException(( "1 validation error detected: Value '{key_id}' at 'KeyId' failed " "to satisfy constraint: Member must point to a key with usage: 'SIGN_VERIFY'" ).format(key_id=key.id)) def __ensure_valid_signing_augorithm(self, key: Key, signing_algorithm): if signing_algorithm not in key.signing_algorithms: raise ValidationException(( "1 validation error detected: Value '{signing_algorithm}' at 'SigningAlgorithm' failed " "to satisfy constraint: Member must satisfy enum value 
set: " "{valid_sign_algorithms}").format( signing_algorithm=signing_algorithm, valid_sign_algorithms=key.signing_algorithms, )) def sign(self, key_id, message, signing_algorithm): """Sign message using generated private key. - signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256 - grant_tokens are not implemented """ key = self.describe_key(key_id) self.__ensure_valid_sign_and_verify_key(key) self.__ensure_valid_signing_augorithm(key, signing_algorithm) # TODO: support more than one hardcoded algorithm based on KeySpec signature = key.private_key.sign( message, padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH), hashes.SHA256(), ) return key.arn, signature, signing_algorithm def verify(self, key_id, message, signature, signing_algorithm): """Verify message using public key from generated private key. - signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256 - grant_tokens are not implemented """ key = self.describe_key(key_id) self.__ensure_valid_sign_and_verify_key(key) self.__ensure_valid_signing_augorithm(key, signing_algorithm) if signing_algorithm not in key.signing_algorithms: raise ValidationException(( "1 validation error detected: Value '{signing_algorithm}' at 'SigningAlgorithm' failed " "to satisfy constraint: Member must satisfy enum value set: " "{valid_sign_algorithms}").format( signing_algorithm=signing_algorithm, valid_sign_algorithms=key.signing_algorithms, )) public_key = key.private_key.public_key() try: # TODO: support more than one hardcoded algorithm based on KeySpec public_key.verify( signature, message, padding.PSS( mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH, ), hashes.SHA256(), ) return key.arn, True, signing_algorithm except InvalidSignature: return key.arn, False, signing_algorithm
class ECRBackend(BaseBackend): def __init__(self, region_name): self.region_name = region_name self.repositories: Dict[str, Repository] = {} self.tagger = TaggingService(tagName="tags") def reset(self): region_name = self.region_name self.__dict__ = {} self.__init__(region_name) def _get_repository(self, name, registry_id=None) -> Repository: repo = self.repositories.get(name) reg_id = registry_id or DEFAULT_REGISTRY_ID if not repo or repo.registry_id != reg_id: raise RepositoryNotFoundException(name, reg_id) return repo @staticmethod def _parse_resource_arn(resource_arn) -> EcrRepositoryArn: match = re.match(ECR_REPOSITORY_ARN_PATTERN, resource_arn) if not match: raise InvalidParameterException( "Invalid parameter at 'resourceArn' failed to satisfy constraint: " "'Invalid ARN'") return EcrRepositoryArn(**match.groupdict()) def describe_repositories(self, registry_id=None, repository_names=None): """ maxResults and nextToken not implemented """ if repository_names: for repository_name in repository_names: if repository_name not in self.repositories: raise RepositoryNotFoundException( repository_name, registry_id or DEFAULT_REGISTRY_ID) repositories = [] for repository in self.repositories.values(): # If a registry_id was supplied, ensure this repository matches if registry_id: if repository.registry_id != registry_id: continue # If a list of repository names was supplied, esure this repository # is in that list if repository_names: if repository.name not in repository_names: continue repositories.append(repository.response_object) return repositories def create_repository( self, repository_name, encryption_config, image_scan_config, image_tag_mutablility, tags, ): if self.repositories.get(repository_name): raise RepositoryAlreadyExistsException(repository_name, DEFAULT_REGISTRY_ID) repository = Repository( region_name=self.region_name, repository_name=repository_name, encryption_config=encryption_config, image_scan_config=image_scan_config, 
image_tag_mutablility=image_tag_mutablility, ) self.repositories[repository_name] = repository self.tagger.tag_resource(repository.arn, tags) return repository def delete_repository(self, repository_name, registry_id=None): repo = self._get_repository(repository_name, registry_id) if repo.images: raise RepositoryNotEmptyException( repository_name, registry_id or DEFAULT_REGISTRY_ID) self.tagger.delete_all_tags_for_resource(repo.arn) return self.repositories.pop(repository_name) def list_images(self, repository_name, registry_id=None): """ maxResults and filtering not implemented """ repository = None found = False if repository_name in self.repositories: repository = self.repositories[repository_name] if registry_id: if repository.registry_id == registry_id: found = True else: found = True if not found: raise RepositoryNotFoundException( repository_name, registry_id or DEFAULT_REGISTRY_ID) images = [] for image in repository.images: images.append(image) return images def describe_images(self, repository_name, registry_id=None, image_ids=None): if repository_name in self.repositories: repository = self.repositories[repository_name] else: raise RepositoryNotFoundException( repository_name, registry_id or DEFAULT_REGISTRY_ID) if image_ids: response = set() for image_id in image_ids: found = False for image in repository.images: if ("imageDigest" in image_id and image.get_image_digest() == image_id["imageDigest"]) or ( "imageTag" in image_id and image_id["imageTag"] in image.image_tags): found = True response.add(image) if not found: image_id_representation = "{imageDigest:'%s', imageTag:'%s'}" % ( image_id.get("imageDigest", "null"), image_id.get("imageTag", "null"), ) raise ImageNotFoundException( image_id=image_id_representation, repository_name=repository_name, registry_id=registry_id or DEFAULT_REGISTRY_ID, ) else: response = [] for image in repository.images: response.append(image) return response def put_image(self, repository_name, image_manifest, image_tag): 
if repository_name in self.repositories: repository = self.repositories[repository_name] else: raise Exception("{0} is not a repository".format(repository_name)) existing_images = list( filter( lambda x: x.response_object["imageManifest"] == image_manifest, repository.images, )) if not existing_images: # this image is not in ECR yet image = Image(image_tag, image_manifest, repository_name) repository.images.append(image) return image else: # update existing image existing_images[0].update_tag(image_tag) return existing_images[0] def batch_get_image( self, repository_name, registry_id=None, image_ids=None, accepted_media_types=None, ): if repository_name in self.repositories: repository = self.repositories[repository_name] else: raise RepositoryNotFoundException( repository_name, registry_id or DEFAULT_REGISTRY_ID) if not image_ids: raise ParamValidationError( msg='Missing required parameter in input: "imageIds"') response = {"images": [], "failures": []} for image_id in image_ids: found = False for image in repository.images: if ("imageDigest" in image_id and image.get_image_digest() == image_id["imageDigest"] ) or ("imageTag" in image_id and image.image_tag == image_id["imageTag"]): found = True response["images"].append(image.response_batch_get_image) if not found: response["failures"].append({ "imageId": { "imageTag": image_id.get("imageTag", "null") }, "failureCode": "ImageNotFound", "failureReason": "Requested image not found", }) return response def batch_delete_image(self, repository_name, registry_id=None, image_ids=None): if repository_name in self.repositories: repository = self.repositories[repository_name] else: raise RepositoryNotFoundException( repository_name, registry_id or DEFAULT_REGISTRY_ID) if not image_ids: raise ParamValidationError( msg='Missing required parameter in input: "imageIds"') response = {"imageIds": [], "failures": []} for image_id in image_ids: image_found = False # Is request missing both digest and tag? 
if "imageDigest" not in image_id and "imageTag" not in image_id: response["failures"].append({ "imageId": {}, "failureCode": "MissingDigestAndTag", "failureReason": "Invalid request parameters: both tag and digest cannot be null", }) continue # If we have a digest, is it valid? if "imageDigest" in image_id: pattern = re.compile(r"^[0-9a-zA-Z_+\.-]+:[0-9a-fA-F]{64}") if not pattern.match(image_id.get("imageDigest")): response["failures"].append({ "imageId": { "imageDigest": image_id.get("imageDigest", "null") }, "failureCode": "InvalidImageDigest", "failureReason": "Invalid request parameters: image digest should satisfy the regex '[a-zA-Z0-9-_+.]+:[a-fA-F0-9]+'", }) continue for num, image in enumerate(repository.images): # Search by matching both digest and tag if "imageDigest" in image_id and "imageTag" in image_id: if (image_id["imageDigest"] == image.get_image_digest() and image_id["imageTag"] in image.image_tags): image_found = True for image_tag in reversed(image.image_tags): repository.images[num].image_tag = image_tag response["imageIds"].append( image.response_batch_delete_image) repository.images[num].remove_tag(image_tag) del repository.images[num] # Search by matching digest elif ("imageDigest" in image_id and image.get_image_digest() == image_id["imageDigest"]): image_found = True for image_tag in reversed(image.image_tags): repository.images[num].image_tag = image_tag response["imageIds"].append( image.response_batch_delete_image) repository.images[num].remove_tag(image_tag) del repository.images[num] # Search by matching tag elif ("imageTag" in image_id and image_id["imageTag"] in image.image_tags): image_found = True repository.images[num].image_tag = image_id["imageTag"] response["imageIds"].append( image.response_batch_delete_image) if len(image.image_tags) > 1: repository.images[num].remove_tag(image_id["imageTag"]) else: repository.images.remove(image) if not image_found: failure_response = { "imageId": {}, "failureCode": "ImageNotFound", 
"failureReason": "Requested image not found", } if "imageDigest" in image_id: failure_response["imageId"][ "imageDigest"] = image_id.get( "imageDigest", "null") if "imageTag" in image_id: failure_response["imageId"]["imageTag"] = image_id.get( "imageTag", "null") response["failures"].append(failure_response) return response def list_tags_for_resource(self, arn): resource = self._parse_resource_arn(arn) repo = self._get_repository(resource.repo_name, resource.account_id) return self.tagger.list_tags_for_resource(repo.arn) def tag_resource(self, arn, tags): resource = self._parse_resource_arn(arn) repo = self._get_repository(resource.repo_name, resource.account_id) self.tagger.tag_resource(repo.arn, tags) return {} def untag_resource(self, arn, tag_keys): resource = self._parse_resource_arn(arn) repo = self._get_repository(resource.repo_name, resource.account_id) self.tagger.untag_resource_using_names(repo.arn, tag_keys) return {} def put_image_tag_mutability(self, registry_id, repository_name, image_tag_mutability): if image_tag_mutability not in ["IMMUTABLE", "MUTABLE"]: raise InvalidParameterException( "Invalid parameter at 'imageTagMutability' failed to satisfy constraint: " "'Member must satisfy enum value set: [IMMUTABLE, MUTABLE]'") repo = self._get_repository(repository_name, registry_id) repo.update(image_tag_mutability=image_tag_mutability) return { "registryId": repo.registry_id, "repositoryName": repository_name, "imageTagMutability": repo.image_tag_mutability, } def put_image_scanning_configuration(self, registry_id, repository_name, image_scan_config): repo = self._get_repository(repository_name, registry_id) repo.update(image_scan_config=image_scan_config) return { "registryId": repo.registry_id, "repositoryName": repository_name, "imageScanningConfiguration": repo.image_scanning_configuration, } def set_repository_policy(self, registry_id, repository_name, policy_text): repo = self._get_repository(repository_name, registry_id) try: 
iam_policy_document_validator = IAMPolicyDocumentValidator( policy_text) # the repository policy can be defined without a resource field iam_policy_document_validator._validate_resource_exist = lambda: None # the repository policy can have the old version 2008-10-17 iam_policy_document_validator._validate_version = lambda: None iam_policy_document_validator.validate() except MalformedPolicyDocument: raise InvalidParameterException( "Invalid parameter at 'PolicyText' failed to satisfy constraint: " "'Invalid repository policy provided'") repo.policy = policy_text return { "registryId": repo.registry_id, "repositoryName": repository_name, "policyText": repo.policy, } def get_repository_policy(self, registry_id, repository_name): repo = self._get_repository(repository_name, registry_id) if not repo.policy: raise RepositoryPolicyNotFoundException(repository_name, repo.registry_id) return { "registryId": repo.registry_id, "repositoryName": repository_name, "policyText": repo.policy, } def delete_repository_policy(self, registry_id, repository_name): repo = self._get_repository(repository_name, registry_id) policy = repo.policy if not policy: raise RepositoryPolicyNotFoundException(repository_name, repo.registry_id) repo.policy = None return { "registryId": repo.registry_id, "repositoryName": repository_name, "policyText": policy, } def put_lifecycle_policy(self, registry_id, repository_name, lifecycle_policy_text): repo = self._get_repository(repository_name, registry_id) validator = EcrLifecyclePolicyValidator(lifecycle_policy_text) validator.validate() repo.lifecycle_policy = lifecycle_policy_text return { "registryId": repo.registry_id, "repositoryName": repository_name, "lifecyclePolicyText": repo.lifecycle_policy, } def get_lifecycle_policy(self, registry_id, repository_name): repo = self._get_repository(repository_name, registry_id) if not repo.lifecycle_policy: raise LifecyclePolicyNotFoundException(repository_name, repo.registry_id) return { "registryId": 
repo.registry_id, "repositoryName": repository_name, "lifecyclePolicyText": repo.lifecycle_policy, "lastEvaluatedAt": iso_8601_datetime_without_milliseconds(datetime.utcnow()), } def delete_lifecycle_policy(self, registry_id, repository_name): repo = self._get_repository(repository_name, registry_id) policy = repo.lifecycle_policy if not policy: raise LifecyclePolicyNotFoundException(repository_name, repo.registry_id) repo.lifecycle_policy = None return { "registryId": repo.registry_id, "repositoryName": repository_name, "lifecyclePolicyText": policy, "lastEvaluatedAt": iso_8601_datetime_without_milliseconds(datetime.utcnow()), }