Example #1
# The `.should` assertions below come from the `sure` library.
from moto.utilities.tagging_service import TaggingService


def test_delete_all_tags_for_resource():
    svc = TaggingService()
    tags = [{"Key": "key_key", "Value": "value_value"}]
    tags2 = [{"Key": "key_key2", "Value": "value_value2"}]
    svc.tag_resource("arn", tags)
    svc.tag_resource("arn", tags2)
    svc.delete_all_tags_for_resource("arn")
    result = svc.list_tags_for_resource("arn")

    {"Tags": []}.should.be.equal(result)
Example #2
class DirectoryServiceBackend(BaseBackend):
    """Implementation of DirectoryService APIs."""
    def __init__(self, region_name=None):
        self.region_name = region_name
        self.directories = {}
        self.tagger = TaggingService()

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """List of dicts representing default VPC endpoints for this service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "ds")

    @staticmethod
    def _verify_subnets(region, vpc_settings):
        """Verify subnets are valid, else raise an exception.

        If settings are valid, add AvailabilityZones to vpc_settings.
        """
        if len(vpc_settings["SubnetIds"]) != 2:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones.")

        # Subnet IDs are checked before the VPC ID.  The Subnet IDs must
        # be valid and in different availability zones.
        try:
            subnets = ec2_backends[region].get_all_subnets(
                subnet_ids=vpc_settings["SubnetIds"])
        except InvalidSubnetIdError as exc:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones.") from exc

        availability_zones = [subnet.availability_zone for subnet in subnets]
        if availability_zones[0] == availability_zones[1]:
            raise ClientException(
                "Invalid subnet ID(s). The two subnets must be in "
                "different Availability Zones.")

        vpcs = ec2_backends[region].describe_vpcs()
        if vpc_settings["VpcId"] not in [x.id for x in vpcs]:
            raise ClientException("Invalid VPC ID.")
        vpc_settings["AvailabilityZones"] = availability_zones

    def connect_directory(
        self,
        region,
        name,
        short_name,
        password,
        description,
        size,
        connect_settings,
        tags,
    ):  # pylint: disable=too-many-arguments
        """Create a fake AD Connector."""
        if len(self.directories) > Directory.CONNECTED_DIRECTORIES_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CONNECTED_DIRECTORIES_LIMIT} directories may be created"
            )

        validate_args([
            ("password", password),
            ("size", size),
            ("name", name),
            ("description", description),
            ("shortName", short_name),
            (
                "connectSettings.vpcSettings.subnetIds",
                connect_settings["SubnetIds"],
            ),
            (
                "connectSettings.customerUserName",
                connect_settings["CustomerUserName"],
            ),
            ("connectSettings.customerDnsIps",
             connect_settings["CustomerDnsIps"]),
        ])
        # ConnectSettings and VpcSettings both have a VpcId and Subnets.
        self._verify_subnets(region, connect_settings)

        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags or []) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            region,
            name,
            password,
            "ADConnector",
            size=size,
            connect_settings=connect_settings,
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags or [])
        return directory.directory_id

    def create_directory(self, region, name, short_name, password, description,
                         size, vpc_settings, tags):  # pylint: disable=too-many-arguments
        """Create a fake Simple Ad Directory."""
        if len(self.directories) > Directory.CLOUDONLY_DIRECTORIES_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CLOUDONLY_DIRECTORIES_LIMIT} directories may be created"
            )

        # botocore doesn't look for missing vpc_settings, but boto3 does.
        if not vpc_settings:
            raise InvalidParameterException("VpcSettings must be specified.")
        validate_args([
            ("password", password),
            ("size", size),
            ("name", name),
            ("description", description),
            ("shortName", short_name),
            ("vpcSettings.subnetIds", vpc_settings["SubnetIds"]),
        ])
        self._verify_subnets(region, vpc_settings)

        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags or []) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            region,
            name,
            password,
            "SimpleAD",
            size=size,
            vpc_settings=vpc_settings,
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags or [])
        return directory.directory_id

    def _validate_directory_id(self, directory_id):
        """Raise an exception if the directory id is invalid or unknown."""
        # Validation of ID takes precedence over a check for its existence.
        validate_args([("directoryId", directory_id)])
        if directory_id not in self.directories:
            raise EntityDoesNotExistException(
                f"Directory {directory_id} does not exist")

    def create_alias(self, directory_id, alias):
        """Create and assign an alias to a directory."""
        self._validate_directory_id(directory_id)

        # The default alias is the directory ID itself.  Check whether
        # this directory was already given a custom alias.
        directory = self.directories[directory_id]
        if directory.alias != directory_id:
            raise InvalidParameterException(
                "The directory in the request already has an alias. That "
                "alias must be deleted before a new alias can be created.")

        # Is the alias already in use?
        if alias in [x.alias for x in self.directories.values()]:
            raise EntityAlreadyExistsException(
                f"Alias '{alias}' already exists.")
        validate_args([("alias", alias)])

        directory.update_alias(alias)
        return {"DirectoryId": directory_id, "Alias": alias}

    def create_microsoft_ad(
        self,
        region,
        name,
        short_name,
        password,
        description,
        vpc_settings,
        edition,
        tags,
    ):  # pylint: disable=too-many-arguments
        """Create a fake Microsoft Ad Directory."""
        if len(self.directories) > Directory.CLOUDONLY_MICROSOFT_AD_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CLOUDONLY_MICROSOFT_AD_LIMIT} directories may be created"
            )

        # boto3 looks for missing vpc_settings for create_microsoft_ad().
        validate_args([
            ("password", password),
            ("edition", edition),
            ("name", name),
            ("description", description),
            ("shortName", short_name),
            ("vpcSettings.subnetIds", vpc_settings["SubnetIds"]),
        ])
        self._verify_subnets(region, vpc_settings)

        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags or []) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            region,
            name,
            password,
            "MicrosoftAD",
            vpc_settings=vpc_settings,
            short_name=short_name,
            description=description,
            edition=edition,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags or [])
        return directory.directory_id

    def delete_directory(self, directory_id):
        """Delete directory with the matching ID."""
        self._validate_directory_id(directory_id)
        self.tagger.delete_all_tags_for_resource(directory_id)
        self.directories.pop(directory_id)
        return directory_id

    def disable_sso(self, directory_id, username=None, password=None):
        """Disable single-sign on for a directory."""
        self._validate_directory_id(directory_id)
        validate_args([("ssoPassword", password), ("userName", username)])
        directory = self.directories[directory_id]
        directory.enable_sso(False)

    def enable_sso(self, directory_id, username=None, password=None):
        """Enable single-sign on for a directory."""
        self._validate_directory_id(directory_id)
        validate_args([("ssoPassword", password), ("userName", username)])

        directory = self.directories[directory_id]
        if directory.alias == directory_id:
            raise ClientException(
                f"An alias is required before enabling SSO. DomainId={directory_id}"
            )

        directory.enable_sso(True)

    @paginate(pagination_model=PAGINATION_MODEL)
    def describe_directories(self,
                             directory_ids=None,
                             next_token=None,
                             limit=0):  # pylint: disable=unused-argument
        """Return info on all directories or directories with matching IDs."""
        for directory_id in directory_ids or self.directories:
            self._validate_directory_id(directory_id)

        directories = list(self.directories.values())
        if directory_ids:
            directories = [
                x for x in directories if x.directory_id in directory_ids
            ]
        return sorted(directories, key=lambda x: x.launch_time)

    def get_directory_limits(self):
        """Return hard-coded limits for the directories."""
        counts = {"SimpleAD": 0, "MicrosoftAD": 0, "ConnectedAD": 0}
        for directory in self.directories.values():
            if directory.directory_type == "SimpleAD":
                counts["SimpleAD"] += 1
            elif directory.directory_type in ("MicrosoftAD",
                                              "SharedMicrosoftAD"):
                counts["MicrosoftAD"] += 1
            elif directory.directory_type == "ADConnector":
                counts["ConnectedAD"] += 1

        return {
            "CloudOnlyDirectoriesLimit": Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyDirectoriesCurrentCount": counts["SimpleAD"],
            "CloudOnlyDirectoriesLimitReached":
            counts["SimpleAD"] == Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyMicrosoftADLimit": Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "CloudOnlyMicrosoftADCurrentCount": counts["MicrosoftAD"],
            "CloudOnlyMicrosoftADLimitReached":
            counts["MicrosoftAD"] == Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "ConnectedDirectoriesLimit": Directory.CONNECTED_DIRECTORIES_LIMIT,
            "ConnectedDirectoriesCurrentCount": counts["ConnectedAD"],
            "ConnectedDirectoriesLimitReached":
            counts["ConnectedAD"] == Directory.CONNECTED_DIRECTORIES_LIMIT,
        }

    def add_tags_to_resource(self, resource_id, tags):
        """Add or overwrite one or more tags for specified directory."""
        self._validate_directory_id(resource_id)
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise TagLimitExceededException("Tag limit exceeded")
        self.tagger.tag_resource(resource_id, tags)

    def remove_tags_from_resource(self, resource_id, tag_keys):
        """Removes tags from a directory."""
        self._validate_directory_id(resource_id)
        self.tagger.untag_resource_using_names(resource_id, tag_keys)

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags_for_resource(
        self,
        resource_id,
        next_token=None,
        limit=None,
    ):  # pylint: disable=unused-argument
        """List all tags on a directory."""
        self._validate_directory_id(resource_id)
        return self.tagger.list_tags_for_resource(resource_id).get("Tags")
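connect_directory, create_directory and create_microsoft_ad above all repeat the same tag-handling steps before tagging the new directory: validate, then enforce the per-directory limit. A condensed sketch of that shared pattern, reusing the exception classes from the example (the helper function itself is hypothetical):

def _check_tags(tagger, tags):
    """Validate tags and enforce the per-directory tag limit (hypothetical helper)."""
    errmsg = tagger.validate_tags(tags or [])
    if errmsg:
        raise ValidationException(errmsg)
    if len(tags or []) > Directory.MAX_TAGS_PER_DIRECTORY:
        raise DirectoryLimitExceededException("Tag Limit is exceeding")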
Example #3
class KmsBackend(BaseBackend):
    def __init__(self):
        self.keys = {}
        self.key_to_aliases = defaultdict(set)
        self.tagger = TaggingService(keyName="TagKey", valueName="TagValue")

    def create_key(self, policy, key_usage, customer_master_key_spec,
                   description, tags, region):
        key = Key(policy, key_usage, customer_master_key_spec, description,
                  region)
        self.keys[key.id] = key
        if tags:
            self.tag_resource(key.id, tags)
        return key

    def update_key_description(self, key_id, description):
        key = self.keys[self.get_key_id(key_id)]
        key.description = description

    def delete_key(self, key_id):
        if key_id in self.keys:
            if key_id in self.key_to_aliases:
                self.key_to_aliases.pop(key_id)
            self.tagger.delete_all_tags_for_resource(key_id)

            return self.keys.pop(key_id)

    def describe_key(self, key_id):
        # allow the different methods (alias, ARN :key/, keyId, ARN alias) to
        # describe key not just KeyId
        key_id = self.get_key_id(key_id)
        if r"alias/" in str(key_id).lower():
            key_id = self.get_key_id_from_alias(key_id.split("alias/")[1])
        return self.keys[self.get_key_id(key_id)]

    def list_keys(self):
        return self.keys.values()

    @staticmethod
    def get_key_id(key_id):
        # Allow use of ARN as well as pure KeyId
        if key_id.startswith("arn:") and ":key/" in key_id:
            return key_id.split(":key/")[1]

        return key_id

    @staticmethod
    def get_alias_name(alias_name):
        # Allow use of ARN as well as alias name
        if alias_name.startswith("arn:") and ":alias/" in alias_name:
            return alias_name.split(":alias/")[1]

        return alias_name

    def any_id_to_key_id(self, key_id):
        """Go from any valid key ID to the raw key ID.

        Acceptable inputs:
        - raw key ID
        - key ARN
        - alias name
        - alias ARN
        """
        key_id = self.get_alias_name(key_id)
        key_id = self.get_key_id(key_id)
        if key_id.startswith("alias/"):
            key_id = self.get_key_id_from_alias(key_id)
        return key_id

    def alias_exists(self, alias_name):
        for aliases in self.key_to_aliases.values():
            if alias_name in aliases:
                return True

        return False

    def add_alias(self, target_key_id, alias_name):
        self.key_to_aliases[target_key_id].add(alias_name)

    def delete_alias(self, alias_name):
        """Delete the alias."""
        for aliases in self.key_to_aliases.values():
            if alias_name in aliases:
                aliases.remove(alias_name)

    def get_all_aliases(self):
        return self.key_to_aliases

    def get_key_id_from_alias(self, alias_name):
        # Exact membership test against each key's alias set; a substring
        # match over a joined string would return the wrong key.
        for key_id, aliases in self.key_to_aliases.items():
            if alias_name in aliases:
                return key_id
        return None

    def enable_key_rotation(self, key_id):
        self.keys[self.get_key_id(key_id)].key_rotation_status = True

    def disable_key_rotation(self, key_id):
        self.keys[self.get_key_id(key_id)].key_rotation_status = False

    def get_key_rotation_status(self, key_id):
        return self.keys[self.get_key_id(key_id)].key_rotation_status

    def put_key_policy(self, key_id, policy):
        self.keys[self.get_key_id(key_id)].policy = policy

    def get_key_policy(self, key_id):
        return self.keys[self.get_key_id(key_id)].policy

    def disable_key(self, key_id):
        self.keys[key_id].enabled = False
        self.keys[key_id].key_state = "Disabled"

    def enable_key(self, key_id):
        self.keys[key_id].enabled = True
        self.keys[key_id].key_state = "Enabled"

    def cancel_key_deletion(self, key_id):
        self.keys[key_id].key_state = "Disabled"
        self.keys[key_id].deletion_date = None

    def schedule_key_deletion(self, key_id, pending_window_in_days):
        if 7 <= pending_window_in_days <= 30:
            self.keys[key_id].enabled = False
            self.keys[key_id].key_state = "PendingDeletion"
            self.keys[key_id].deletion_date = datetime.now() + timedelta(
                days=pending_window_in_days)
            return unix_time(self.keys[key_id].deletion_date)

    def encrypt(self, key_id, plaintext, encryption_context):
        key_id = self.any_id_to_key_id(key_id)

        ciphertext_blob = encrypt(
            master_keys=self.keys,
            key_id=key_id,
            plaintext=plaintext,
            encryption_context=encryption_context,
        )
        arn = self.keys[key_id].arn
        return ciphertext_blob, arn

    def decrypt(self, ciphertext_blob, encryption_context):
        plaintext, key_id = decrypt(
            master_keys=self.keys,
            ciphertext_blob=ciphertext_blob,
            encryption_context=encryption_context,
        )
        arn = self.keys[key_id].arn
        return plaintext, arn

    def re_encrypt(
        self,
        ciphertext_blob,
        source_encryption_context,
        destination_key_id,
        destination_encryption_context,
    ):
        destination_key_id = self.any_id_to_key_id(destination_key_id)

        plaintext, decrypting_arn = self.decrypt(
            ciphertext_blob=ciphertext_blob,
            encryption_context=source_encryption_context,
        )
        new_ciphertext_blob, encrypting_arn = self.encrypt(
            key_id=destination_key_id,
            plaintext=plaintext,
            encryption_context=destination_encryption_context,
        )
        return new_ciphertext_blob, decrypting_arn, encrypting_arn

    def generate_data_key(self, key_id, encryption_context, number_of_bytes,
                          key_spec, grant_tokens):
        key_id = self.any_id_to_key_id(key_id)

        if key_spec:
            # Note: Actual validation of key_spec is done in kms.responses
            if key_spec == "AES_128":
                plaintext_len = 16
            else:
                plaintext_len = 32
        else:
            plaintext_len = number_of_bytes

        plaintext = os.urandom(plaintext_len)

        ciphertext_blob, arn = self.encrypt(
            key_id=key_id,
            plaintext=plaintext,
            encryption_context=encryption_context)

        return plaintext, ciphertext_blob, arn

    def list_resource_tags(self, key_id):
        if key_id in self.keys:
            return self.tagger.list_tags_for_resource(key_id)
        raise JsonRESTError(
            "NotFoundException",
            "The request was rejected because the specified entity or resource could not be found.",
        )

    def tag_resource(self, key_id, tags):
        if key_id in self.keys:
            self.tagger.tag_resource(key_id, tags)
            return {}
        raise JsonRESTError(
            "NotFoundException",
            "The request was rejected because the specified entity or resource could not be found.",
        )

    def untag_resource(self, key_id, tag_names):
        if key_id in self.keys:
            self.tagger.untag_resource_using_names(key_id, tag_names)
            return {}
        raise JsonRESTError(
            "NotFoundException",
            "The request was rejected because the specified entity or resource could not be found.",
        )
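KMS tags are shaped as {"TagKey": ..., "TagValue": ...} rather than the default {"Key": ..., "Value": ...}, which is why this backend builds its tagger with custom key and value names. A minimal sketch, assuming the constructor keywords shown in this example:

tagger = TaggingService(keyName="TagKey", valueName="TagValue")
tagger.tag_resource("1234abcd-key-id", [{"TagKey": "team", "TagValue": "infra"}])
# The stored tags come back under the configured names:
# {"Tags": [{"TagKey": "team", "TagValue": "infra"}]}
print(tagger.list_tags_for_resource("1234abcd-key-id"))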
Example #4
class ECRBackend(BaseBackend):
    def __init__(self, region_name):
        self.region_name = region_name
        self.registry_policy = None
        self.replication_config = {"rules": []}
        self.repositories: Dict[str, Repository] = {}
        self.tagger = TaggingService(tag_name="tags")

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint service."""
        docker_endpoint = {
            "AcceptanceRequired": False,
            "AvailabilityZones": zones,
            "BaseEndpointDnsNames":
            [f"dkr.ecr.{service_region}.vpce.amazonaws.com"],
            "ManagesVpcEndpoints": False,
            "Owner": "amazon",
            "PrivateDnsName": f"*.dkr.ecr.{service_region}.amazonaws.com",
            "PrivateDnsNameVerificationState": "verified",
            "PrivateDnsNames":
            [{"PrivateDnsName": f"*.dkr.ecr.{service_region}.amazonaws.com"}],
            "ServiceId": f"vpce-svc-{BaseBackend.vpce_random_number()}",
            "ServiceName": f"com.amazonaws.{service_region}.ecr.dkr",
            "ServiceType": [{"ServiceType": "Interface"}],
            "Tags": [],
            "VpcEndpointPolicySupported": True,
        }
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "api.ecr",
            special_service_name="ecr.api") + [docker_endpoint]

    def _get_repository(self, name, registry_id=None) -> Repository:
        repo = self.repositories.get(name)
        reg_id = registry_id or DEFAULT_REGISTRY_ID

        if not repo or repo.registry_id != reg_id:
            raise RepositoryNotFoundException(name, reg_id)
        return repo

    @staticmethod
    def _parse_resource_arn(resource_arn) -> EcrRepositoryArn:
        match = re.match(ECR_REPOSITORY_ARN_PATTERN, resource_arn)
        if not match:
            raise InvalidParameterException(
                "Invalid parameter at 'resourceArn' failed to satisfy constraint: "
                "'Invalid ARN'")
        return EcrRepositoryArn(**match.groupdict())

    def describe_repositories(self, registry_id=None, repository_names=None):
        """
        maxResults and nextToken not implemented
        """
        if repository_names:
            for repository_name in repository_names:
                if repository_name not in self.repositories:
                    raise RepositoryNotFoundException(
                        repository_name, registry_id or DEFAULT_REGISTRY_ID)

        repositories = []
        for repository in self.repositories.values():
            # If a registry_id was supplied, ensure this repository matches
            if registry_id:
                if repository.registry_id != registry_id:
                    continue
            # If a list of repository names was supplied, ensure this
            # repository is in that list
            if repository_names:
                if repository.name not in repository_names:
                    continue
            repositories.append(repository.response_object)
        return repositories

    def create_repository(
        self,
        repository_name,
        registry_id,
        encryption_config,
        image_scan_config,
        image_tag_mutablility,
        tags,
    ):
        if self.repositories.get(repository_name):
            raise RepositoryAlreadyExistsException(repository_name,
                                                   DEFAULT_REGISTRY_ID)

        repository = Repository(
            region_name=self.region_name,
            repository_name=repository_name,
            registry_id=registry_id,
            encryption_config=encryption_config,
            image_scan_config=image_scan_config,
            image_tag_mutablility=image_tag_mutablility,
        )
        self.repositories[repository_name] = repository
        self.tagger.tag_resource(repository.arn, tags)

        return repository

    def delete_repository(self,
                          repository_name,
                          registry_id=None,
                          force=False):
        repo = self._get_repository(repository_name, registry_id)

        if repo.images and not force:
            raise RepositoryNotEmptyException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        self.tagger.delete_all_tags_for_resource(repo.arn)
        return self.repositories.pop(repository_name)

    def list_images(self, repository_name, registry_id=None):
        """
        maxResults and filtering not implemented
        """
        repository = None
        found = False
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
            if registry_id:
                if repository.registry_id == registry_id:
                    found = True
            else:
                found = True

        if not found:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        return list(repository.images)

    def describe_images(self,
                        repository_name,
                        registry_id=None,
                        image_ids=None):
        repository = self._get_repository(repository_name, registry_id)

        if image_ids:
            response = set(
                repository._get_image(image_id.get("imageTag"),
                                      image_id.get("imageDigest"))
                for image_id in image_ids)

        else:
            response = list(repository.images)

        return response

    def put_image(self, repository_name, image_manifest, image_tag):
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise Exception("{0} is not a repository".format(repository_name))

        existing_images = list(
            filter(
                lambda x: x.response_object["imageManifest"] == image_manifest,
                repository.images,
            ))
        if not existing_images:
            # this image is not in ECR yet
            image = Image(image_tag, image_manifest, repository_name)
            repository.images.append(image)
            return image
        else:
            # update existing image
            existing_images[0].update_tag(image_tag)
            return existing_images[0]

    def batch_get_image(self,
                        repository_name,
                        registry_id=None,
                        image_ids=None):
        """
        The parameter AcceptedMediaTypes has not yet been implemented
        """
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        if not image_ids:
            raise ParamValidationError(
                msg='Missing required parameter in input: "imageIds"')

        response = {"images": [], "failures": []}

        for image_id in image_ids:
            found = False
            for image in repository.images:
                if ("imageDigest" in image_id
                        and image.get_image_digest() == image_id["imageDigest"]
                    ) or ("imageTag" in image_id
                          and image.image_tag == image_id["imageTag"]):
                    found = True
                    response["images"].append(image.response_batch_get_image)

            # Record a failure for any requested image id that was not found.
            if not found:
                response["failures"].append({
                    "imageId": {"imageTag": image_id.get("imageTag", "null")},
                    "failureCode": "ImageNotFound",
                    "failureReason": "Requested image not found",
                })

        return response

    def batch_delete_image(self,
                           repository_name,
                           registry_id=None,
                           image_ids=None):
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        if not image_ids:
            raise ParamValidationError(
                msg='Missing required parameter in input: "imageIds"')

        response = {"imageIds": [], "failures": []}

        for image_id in image_ids:
            image_found = False

            # Is request missing both digest and tag?
            if "imageDigest" not in image_id and "imageTag" not in image_id:
                response["failures"].append({
                    "imageId": {},
                    "failureCode":
                    "MissingDigestAndTag",
                    "failureReason":
                    "Invalid request parameters: both tag and digest cannot be null",
                })
                continue

            # If we have a digest, is it valid?
            if "imageDigest" in image_id:
                pattern = re.compile(r"^[0-9a-zA-Z_+\.-]+:[0-9a-fA-F]{64}")
                if not pattern.match(image_id.get("imageDigest")):
                    response["failures"].append({
                        "imageId": {
                            "imageDigest": image_id.get("imageDigest", "null")
                        },
                        "failureCode":
                        "InvalidImageDigest",
                        "failureReason":
                        "Invalid request parameters: image digest should satisfy the regex '[a-zA-Z0-9-_+.]+:[a-fA-F0-9]+'",
                    })
                    continue

            for num, image in enumerate(repository.images):

                # Search by matching both digest and tag
                if "imageDigest" in image_id and "imageTag" in image_id:
                    if (image_id["imageDigest"] == image.get_image_digest()
                            and image_id["imageTag"] in image.image_tags):
                        image_found = True
                        for image_tag in reversed(image.image_tags):
                            repository.images[num].image_tag = image_tag
                            response["imageIds"].append(
                                image.response_batch_delete_image)
                            repository.images[num].remove_tag(image_tag)
                        del repository.images[num]

                # Search by matching digest
                elif ("imageDigest" in image_id
                      and image.get_image_digest() == image_id["imageDigest"]):
                    image_found = True
                    for image_tag in reversed(image.image_tags):
                        repository.images[num].image_tag = image_tag
                        response["imageIds"].append(
                            image.response_batch_delete_image)
                        repository.images[num].remove_tag(image_tag)
                    del repository.images[num]

                # Search by matching tag
                elif ("imageTag" in image_id
                      and image_id["imageTag"] in image.image_tags):
                    image_found = True
                    repository.images[num].image_tag = image_id["imageTag"]
                    response["imageIds"].append(
                        image.response_batch_delete_image)
                    if len(image.image_tags) > 1:
                        repository.images[num].remove_tag(image_id["imageTag"])
                    else:
                        repository.images.remove(image)

            if not image_found:
                failure_response = {
                    "imageId": {},
                    "failureCode": "ImageNotFound",
                    "failureReason": "Requested image not found",
                }

                if "imageDigest" in image_id:
                    failure_response["imageId"]["imageDigest"] = image_id.get(
                        "imageDigest", "null")

                if "imageTag" in image_id:
                    failure_response["imageId"]["imageTag"] = image_id.get(
                        "imageTag", "null")

                response["failures"].append(failure_response)

        return response

    def list_tags_for_resource(self, arn):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)

        return self.tagger.list_tags_for_resource(repo.arn)

    def tag_resource(self, arn, tags):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)
        self.tagger.tag_resource(repo.arn, tags)

        return {}

    def untag_resource(self, arn, tag_keys):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)
        self.tagger.untag_resource_using_names(repo.arn, tag_keys)

        return {}

    def put_image_tag_mutability(self, registry_id, repository_name,
                                 image_tag_mutability):
        if image_tag_mutability not in ["IMMUTABLE", "MUTABLE"]:
            raise InvalidParameterException(
                "Invalid parameter at 'imageTagMutability' failed to satisfy constraint: "
                "'Member must satisfy enum value set: [IMMUTABLE, MUTABLE]'")

        repo = self._get_repository(repository_name, registry_id)
        repo.update(image_tag_mutability=image_tag_mutability)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageTagMutability": repo.image_tag_mutability,
        }

    def put_image_scanning_configuration(self, registry_id, repository_name,
                                         image_scan_config):
        repo = self._get_repository(repository_name, registry_id)
        repo.update(image_scan_config=image_scan_config)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageScanningConfiguration": repo.image_scanning_configuration,
        }

    def set_repository_policy(self, registry_id, repository_name, policy_text):
        repo = self._get_repository(repository_name, registry_id)

        try:
            iam_policy_document_validator = IAMPolicyDocumentValidator(
                policy_text)
            # the repository policy can be defined without a resource field
            iam_policy_document_validator._validate_resource_exist = lambda: None
            # the repository policy can have the old version 2008-10-17
            iam_policy_document_validator._validate_version = lambda: None
            iam_policy_document_validator.validate()
        except MalformedPolicyDocument:
            raise InvalidParameterException(
                "Invalid parameter at 'PolicyText' failed to satisfy constraint: "
                "'Invalid repository policy provided'")

        repo.policy = policy_text

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": repo.policy,
        }

    def get_repository_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)

        if not repo.policy:
            raise RepositoryPolicyNotFoundException(repository_name,
                                                    repo.registry_id)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": repo.policy,
        }

    def delete_repository_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)
        policy = repo.policy

        if not policy:
            raise RepositoryPolicyNotFoundException(repository_name,
                                                    repo.registry_id)

        repo.policy = None

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": policy,
        }

    def put_lifecycle_policy(self, registry_id, repository_name,
                             lifecycle_policy_text):
        repo = self._get_repository(repository_name, registry_id)

        validator = EcrLifecyclePolicyValidator(lifecycle_policy_text)
        validator.validate()

        repo.lifecycle_policy = lifecycle_policy_text

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "lifecyclePolicyText": repo.lifecycle_policy,
        }

    def get_lifecycle_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)

        if not repo.lifecycle_policy:
            raise LifecyclePolicyNotFoundException(repository_name,
                                                   repo.registry_id)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "lifecyclePolicyText": repo.lifecycle_policy,
            "lastEvaluatedAt":
            iso_8601_datetime_without_milliseconds(datetime.utcnow()),
        }

    def delete_lifecycle_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)
        policy = repo.lifecycle_policy

        if not policy:
            raise LifecyclePolicyNotFoundException(repository_name,
                                                   repo.registry_id)

        repo.lifecycle_policy = None

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "lifecyclePolicyText": policy,
            "lastEvaluatedAt":
            iso_8601_datetime_without_milliseconds(datetime.utcnow()),
        }

    def _validate_registry_policy_action(self, policy_text):
        # only CreateRepository & ReplicateImage actions are allowed
        VALID_ACTIONS = {"ecr:CreateRepository", "ecr:ReplicateImage"}

        policy = json.loads(policy_text)
        for statement in policy["Statement"]:
            action = statement["Action"]
            if isinstance(action, str):
                action = [action]
            if set(action) - VALID_ACTIONS:
                raise MalformedPolicyDocument()

    def put_registry_policy(self, policy_text):
        try:
            iam_policy_document_validator = IAMPolicyDocumentValidator(
                policy_text)
            iam_policy_document_validator.validate()

            self._validate_registry_policy_action(policy_text)
        except MalformedPolicyDocument:
            raise InvalidParameterException(
                "Invalid parameter at 'PolicyText' failed to satisfy constraint: "
                "'Invalid registry policy provided'")

        self.registry_policy = policy_text

        return {
            "registryId": get_account_id(),
            "policyText": policy_text,
        }

    def get_registry_policy(self):
        if not self.registry_policy:
            raise RegistryPolicyNotFoundException(get_account_id())

        return {
            "registryId": get_account_id(),
            "policyText": self.registry_policy,
        }

    def delete_registry_policy(self):
        policy = self.registry_policy
        if not policy:
            raise RegistryPolicyNotFoundException(get_account_id())

        self.registry_policy = None

        return {
            "registryId": get_account_id(),
            "policyText": policy,
        }

    def start_image_scan(self, registry_id, repository_name, image_id):
        repo = self._get_repository(repository_name, registry_id)

        image = repo._get_image(image_id.get("imageTag"),
                                image_id.get("imageDigest"))

        # scanning an image is only allowed once per day
        if (image.last_scan
                and image.last_scan.date() == datetime.today().date()):
            raise LimitExceededException()

        image.last_scan = datetime.today()

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageId": {
                "imageDigest": image.image_digest,
                "imageTag": image.image_tag,
            },
            "imageScanStatus": {
                "status": "IN_PROGRESS"
            },
        }

    def describe_image_scan_findings(self, registry_id, repository_name,
                                     image_id):
        repo = self._get_repository(repository_name, registry_id)

        image = repo._get_image(image_id.get("imageTag"),
                                image_id.get("imageDigest"))

        if not image.last_scan:
            image_id_rep = "{{imageDigest:'{0}', imageTag:'{1}'}}".format(
                image_id.get("imageDigest") or "null",
                image_id.get("imageTag") or "null",
            )
            raise ScanNotFoundException(
                image_id=image_id_rep,
                repository_name=repository_name,
                registry_id=repo.registry_id,
            )

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageId": {
                "imageDigest": image.image_digest,
                "imageTag": image.image_tag,
            },
            "imageScanStatus": {
                "status": "COMPLETE",
                "description": "The scan was completed successfully.",
            },
            "imageScanFindings": {
                "imageScanCompletedAt":
                iso_8601_datetime_without_milliseconds(image.last_scan),
                "vulnerabilitySourceUpdatedAt":
                iso_8601_datetime_without_milliseconds(datetime.utcnow()),
                "findings": [{
                    "name":
                    "CVE-9999-9999",
                    "uri":
                    "https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-9999-9999",
                    "severity":
                    "HIGH",
                    "attributes": [
                        {
                            "key": "package_version",
                            "value": "9.9.9"
                        },
                        {
                            "key": "package_name",
                            "value": "moto_fake"
                        },
                        {
                            "key": "CVSS2_VECTOR",
                            "value": "AV:N/AC:L/Au:N/C:P/I:P/A:P",
                        },
                        {
                            "key": "CVSS2_SCORE",
                            "value": "7.5"
                        },
                    ],
                }],
                "findingSeverityCounts": {
                    "HIGH": 1
                },
            },
        }

    def put_replication_configuration(self, replication_config):
        rules = replication_config["rules"]
        if len(rules) > 1:
            raise ValidationException("This feature is disabled")

        if len(rules) == 1:
            for dest in rules[0]["destinations"]:
                if (dest["region"] == self.region_name
                        and dest["registryId"] == DEFAULT_REGISTRY_ID):
                    raise InvalidParameterException(
                        "Invalid parameter at 'replicationConfiguration' failed to satisfy constraint: "
                        "'Replication destination cannot be the same as the source registry'"
                    )

        self.replication_config = replication_config

        return {"replicationConfiguration": replication_config}

    def describe_registry(self):
        return {
            "registryId": DEFAULT_REGISTRY_ID,
            "replicationConfiguration": self.replication_config,
        }
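ECR responses wrap tags in a lowercase "tags" field, so this backend constructs its tagger with tag_name="tags" and keys every tag operation by the repository ARN. A minimal sketch, assuming that constructor keyword; the ARN is illustrative only:

tagger = TaggingService(tag_name="tags")
arn = "arn:aws:ecr:us-east-1:012345678910:repository/demo"  # illustrative ARN
tagger.tag_resource(arn, [{"Key": "owner", "Value": "ops"}])
print(tagger.list_tags_for_resource(arn))
# -> {"tags": [{"Key": "owner", "Value": "ops"}]}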
Example #5
class EventsBackend(BaseBackend):
    ACCOUNT_ID = re.compile(r"^(\d{1,12}|\*)$")
    STATEMENT_ID = re.compile(r"^[a-zA-Z0-9-_]{1,64}$")

    def __init__(self, region_name):
        self.rules = {}
        # This list tracks the order in which rules were added, since
        # Python 2.6 has no OrderedDict.
        self.rules_order = []
        self.next_tokens = {}
        self.region_name = region_name
        self.event_buses = {}
        self.event_sources = {}
        self.archives = {}
        self.replays = {}
        self.tagger = TaggingService()

        self._add_default_event_bus()

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def _add_default_event_bus(self):
        self.event_buses["default"] = EventBus(self.region_name, "default")

    def _get_rule_by_index(self, i):
        return self.rules.get(self.rules_order[i])

    def _gen_next_token(self, index):
        # Requires `import base64`; bytes.encode("base64") only worked
        # on Python 2.
        token = base64.b64encode(os.urandom(128)).decode("utf-8")
        self.next_tokens[token] = index
        return token

    def _process_token_and_limits(self,
                                  array_len,
                                  next_token=None,
                                  limit=None):
        start_index = 0
        end_index = array_len
        new_next_token = None

        if next_token:
            start_index = self.next_tokens.pop(next_token, 0)

        if limit is not None:
            new_end_index = start_index + int(limit)
            if new_end_index < end_index:
                end_index = new_end_index
                new_next_token = self._gen_next_token(end_index)

        return start_index, end_index, new_next_token

    def _get_event_bus(self, name):
        event_bus_name = name.split("/")[-1]

        event_bus = self.event_buses.get(event_bus_name)
        if not event_bus:
            raise ResourceNotFoundException(
                "Event bus {} does not exist.".format(event_bus_name))

        return event_bus

    def _get_replay(self, name):
        replay = self.replays.get(name)
        if not replay:
            raise ResourceNotFoundException(
                "Replay {} does not exist.".format(name))

        return replay

    def delete_rule(self, name):
        self.rules_order.remove(name)
        arn = self.rules.get(name).arn
        if self.tagger.has_tags(arn):
            self.tagger.delete_all_tags_for_resource(arn)
        return self.rules.pop(name) is not None

    def describe_rule(self, name):
        return self.rules.get(name)

    def disable_rule(self, name):
        if name in self.rules:
            self.rules[name].disable()
            return True

        return False

    def enable_rule(self, name):
        if name in self.rules:
            self.rules[name].enable()
            return True

        return False

    def list_rule_names_by_target(self,
                                  target_arn,
                                  next_token=None,
                                  limit=None):
        matching_rules = []
        return_obj = {}

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(self.rules), next_token, limit)

        for i in range(start_index, end_index):
            rule = self._get_rule_by_index(i)
            for target in rule.targets:
                if target["Arn"] == target_arn:
                    matching_rules.append(rule.name)

        return_obj["RuleNames"] = matching_rules
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def list_rules(self, prefix=None, next_token=None, limit=None):
        match_string = ".*"
        if prefix is not None:
            match_string = "^" + prefix + match_string

        match_regex = re.compile(match_string)

        matching_rules = []
        return_obj = {}

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(self.rules), next_token, limit)

        for i in range(start_index, end_index):
            rule = self._get_rule_by_index(i)
            if match_regex.match(rule.name):
                matching_rules.append(rule)

        return_obj["Rules"] = matching_rules
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def list_targets_by_rule(self, rule, next_token=None, limit=None):
        # We'll let a KeyError exception be thrown for response to handle if
        # rule doesn't exist.
        rule = self.rules[rule]

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(rule.targets), next_token, limit)

        returned_targets = rule.targets[start_index:end_index]
        return_obj = {}

        return_obj["Targets"] = returned_targets
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def update_rule(self, rule, **kwargs):
        rule.event_pattern = kwargs.get("EventPattern") or rule.event_pattern
        rule.schedule_exp = kwargs.get(
            "ScheduleExpression") or rule.schedule_exp
        rule.state = kwargs.get("State") or rule.state
        rule.description = kwargs.get("Description") or rule.description
        rule.role_arn = kwargs.get("RoleArn") or rule.role_arn
        rule.event_bus_name = kwargs.get("EventBusName") or rule.event_bus_name

    def put_rule(self, name, **kwargs):
        if kwargs.get("ScheduleExpression"
                      ) and kwargs.get("EventBusName") != "default":
            raise ValidationException(
                "ScheduleExpression is supported only on the default event bus."
            )

        if name in self.rules:
            self.update_rule(self.rules[name], **kwargs)
            new_rule = self.rules[name]
        else:
            new_rule = Rule(name, self.region_name, **kwargs)
            self.rules[new_rule.name] = new_rule
            self.rules_order.append(new_rule.name)
        return new_rule

    def put_targets(self, name, event_bus_name, targets):
        # super simple ARN check
        invalid_arn = next(
            (target["Arn"] for target in targets
             if not re.match(r"arn:[\d\w:\-/]*", target["Arn"])),
            None,
        )
        if invalid_arn:
            raise ValidationException(
                "Parameter {} is not valid. "
                "Reason: Provided Arn is not in correct format.".format(
                    invalid_arn))

        for target in targets:
            arn = target["Arn"]

            if (":sqs:" in arn and arn.endswith(".fifo")
                    and not target.get("SqsParameters")):
                raise ValidationException(
                    "Parameter(s) SqsParameters must be specified for target: {}."
                    .format(target["Id"]))

        rule = self.rules.get(name)

        if not rule:
            raise ResourceNotFoundException(
                "Rule {0} does not exist on EventBus {1}.".format(
                    name, event_bus_name))

        rule.put_targets(targets)

    def put_events(self, events):
        num_events = len(events)

        if num_events > 10:
            # the real error text is longer; its Value list echoes all the submitted entries
            raise ValidationException(
                "1 validation error detected: "
                "Value '[PutEventsRequestEntry]' at 'entries' failed to satisfy constraint: "
                "Member must have length less than or equal to 10")

        entries = []
        for event in events:
            if "Source" not in event:
                entries.append({
                    "ErrorCode":
                    "InvalidArgument",
                    "ErrorMessage":
                    "Parameter Source is not valid. Reason: Source is a required argument.",
                })
            elif "DetailType" not in event:
                entries.append({
                    "ErrorCode":
                    "InvalidArgument",
                    "ErrorMessage":
                    "Parameter DetailType is not valid. Reason: DetailType is a required argument.",
                })
            elif "Detail" not in event:
                entries.append({
                    "ErrorCode":
                    "InvalidArgument",
                    "ErrorMessage":
                    "Parameter Detail is not valid. Reason: Detail is a required argument.",
                })
            else:
                try:
                    json.loads(event["Detail"])
                except ValueError:  # json.JSONDecodeError exists since Python 3.5
                    entries.append({
                        "ErrorCode": "MalformedDetail",
                        "ErrorMessage": "Detail is malformed.",
                    })
                    continue

                event_id = str(uuid4())
                entries.append({"EventId": event_id})

                # if 'EventBusName' is not explicitly set, the event goes to the default bus
                event_bus_name = event.get("EventBusName", "default")

                for rule in self.rules.values():
                    rule.send_to_targets(
                        event_bus_name,
                        {
                            "version": "0",
                            "id": event_id,
                            "detail-type": event["DetailType"],
                            "source": event["Source"],
                            "account": ACCOUNT_ID,
                            "time": event.get("Time",
                                              unix_time(datetime.utcnow())),
                            "region": self.region_name,
                            "resources": event.get("Resources", []),
                            "detail": json.loads(event["Detail"]),
                        },
                    )

        return entries

    def remove_targets(self, name, event_bus_name, ids):
        rule = self.rules.get(name)

        if not rule:
            raise ResourceNotFoundException(
                "Rule {0} does not exist on EventBus {1}.".format(
                    name, event_bus_name))

        rule.remove_targets(ids)

    def test_event_pattern(self):
        raise NotImplementedError()

    def put_permission(self, event_bus_name, action, principal, statement_id):
        if not event_bus_name:
            event_bus_name = "default"

        event_bus = self.describe_event_bus(event_bus_name)

        if action is None or action != "events:PutEvents":
            raise JsonRESTError(
                "ValidationException",
                "Provided value in parameter 'action' is not supported.",
            )

        if principal is None or self.ACCOUNT_ID.match(principal) is None:
            raise JsonRESTError("InvalidParameterValue",
                                r"Principal must match ^(\d{1,12}|\*)$")

        if statement_id is None or self.STATEMENT_ID.match(
                statement_id) is None:
            raise JsonRESTError(
                "InvalidParameterValue",
                r"StatementId must match ^[a-zA-Z0-9-_]{1,64}$")

        event_bus._permissions[statement_id] = {
            "Action": action,
            "Principal": principal,
        }

    def remove_permission(self, event_bus_name, statement_id):
        if not event_bus_name:
            event_bus_name = "default"

        event_bus = self.describe_event_bus(event_bus_name)

        if not event_bus._permissions:
            raise JsonRESTError("ResourceNotFoundException",
                                "EventBus does not have a policy.")

        if not event_bus._permissions.pop(statement_id, None):
            raise JsonRESTError(
                "ResourceNotFoundException",
                "Statement with the provided id does not exist.",
            )

    def describe_event_bus(self, name):
        if not name:
            name = "default"

        event_bus = self._get_event_bus(name)

        return event_bus

    def create_event_bus(self, name, event_source_name=None):
        if name in self.event_buses:
            raise JsonRESTError(
                "ResourceAlreadyExistsException",
                "Event bus {} already exists.".format(name),
            )

        if not event_source_name and "/" in name:
            raise JsonRESTError("ValidationException",
                                "Event bus name must not contain '/'.")

        if event_source_name and event_source_name not in self.event_sources:
            raise JsonRESTError(
                "ResourceNotFoundException",
                "Event source {} does not exist.".format(event_source_name),
            )

        self.event_buses[name] = EventBus(self.region_name, name)

        return self.event_buses[name]

    def list_event_buses(self, name_prefix):
        if name_prefix:
            return [
                event_bus for event_bus in self.event_buses.values()
                if event_bus.name.startswith(name_prefix)
            ]

        return list(self.event_buses.values())

    def delete_event_bus(self, name):
        if name == "default":
            raise JsonRESTError("ValidationException",
                                "Cannot delete event bus default.")
        self.event_buses.pop(name, None)

    def list_tags_for_resource(self, arn):
        name = arn.split("/")[-1]
        if name in self.rules:
            return self.tagger.list_tags_for_resource(self.rules[name].arn)
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def tag_resource(self, arn, tags):
        name = arn.split("/")[-1]
        if name in self.rules:
            self.tagger.tag_resource(self.rules[name].arn, tags)
            return {}
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def untag_resource(self, arn, tag_names):
        name = arn.split("/")[-1]
        if name in self.rules:
            self.tagger.untag_resource_using_names(self.rules[name].arn,
                                                   tag_names)
            return {}
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def create_archive(self, name, source_arn, description, event_pattern,
                       retention):
        if len(name) > 48:
            raise ValidationException(
                " 1 validation error detected: "
                "Value '{}' at 'archiveName' failed to satisfy constraint: "
                "Member must have length less than or equal to 48".format(
                    name))

        event_bus = self._get_event_bus(source_arn)

        if name in self.archives:
            raise ResourceAlreadyExistsException(
                "Archive {} already exists.".format(name))

        archive = Archive(self.region_name, name, source_arn, description,
                          event_pattern, retention)

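        # Replayed events carry a "replay-name" field; requiring that the
        # field not exist keeps a replay from re-archiving its own events.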
        rule_event_pattern = json.loads(event_pattern or "{}")
        rule_event_pattern["replay-name"] = [{"exists": False}]

        rule = self.put_rule(
            "Events-Archive-{}".format(name), **{
                "EventPattern": json.dumps(rule_event_pattern),
                "EventBusName": event_bus.name,
                "ManagedBy": "prod.vhs.events.aws.internal",
            })
        self.put_targets(
            rule.name,
            rule.event_bus_name,
            [{
                "Id": rule.name,
                "Arn": "arn:aws:events:{}:::".format(self.region_name),
                "InputTransformer": {
                    "InputPathsMap": {},
                    "InputTemplate":
                    json.dumps({
                        "archive-arn":
                        "{0}:{1}".format(archive.arn, archive.uuid),
                        "event":
                        "<aws.events.event.json>",
                        "ingestion-time":
                        "<aws.events.event.ingestion-time>",
                    }),
                },
            }],
        )

        self.archives[name] = archive

        return archive

    def describe_archive(self, name):
        archive = self.archives.get(name)

        if not archive:
            raise ResourceNotFoundException(
                "Archive {} does not exist.".format(name))

        return archive.describe()

    def list_archives(self, name_prefix, source_arn, state):
        if [name_prefix, source_arn, state].count(None) < 2:
            raise ValidationException(
                "At most one filter is allowed for ListArchives. "
                "Use either : State, EventSourceArn, or NamePrefix.")

        if state and state not in Archive.VALID_STATES:
            raise ValidationException(
                "1 validation error detected: "
                "Value '{0}' at 'state' failed to satisfy constraint: "
                "Member must satisfy enum value set: "
                "[{1}]".format(state, ", ".join(Archive.VALID_STATES)))

        if [name_prefix, source_arn, state].count(None) == 3:
            return [
                archive.describe_short() for archive in self.archives.values()
            ]

        result = []

        for archive in self.archives.values():
            if name_prefix and archive.name.startswith(name_prefix):
                result.append(archive.describe_short())
            elif source_arn and archive.source_arn == source_arn:
                result.append(archive.describe_short())
            elif state and archive.state == state:
                result.append(archive.describe_short())

        return result

    def update_archive(self, name, description, event_pattern, retention):
        archive = self.archives.get(name)

        if not archive:
            raise ResourceNotFoundException(
                "Archive {} does not exist.".format(name))

        archive.update(description, event_pattern, retention)

        return {
            "ArchiveArn": archive.arn,
            "CreationTime": archive.creation_time,
            "State": archive.state,
        }

    def delete_archive(self, name):
        archive = self.archives.get(name)

        if not archive:
            raise ResourceNotFoundException(
                "Archive {} does not exist.".format(name))

        archive.delete(self.region_name)

    def start_replay(self, name, description, source_arn, start_time, end_time,
                     destination):
        event_bus_arn = destination["Arn"]
        event_bus_arn_pattern = r"^arn:aws:events:[a-zA-Z0-9-]+:\d{12}:event-bus/"
        if not re.match(event_bus_arn_pattern, event_bus_arn):
            raise ValidationException(
                "Parameter Destination.Arn is not valid. "
                "Reason: Must contain an event bus ARN.")

        self._get_event_bus(event_bus_arn)

        archive_name = source_arn.split("/")[-1]
        archive = self.archives.get(archive_name)
        if not archive:
            raise ValidationException(
                "Parameter EventSourceArn is not valid. "
                "Reason: Archive {} does not exist.".format(archive_name))

        if event_bus_arn != archive.source_arn:
            raise ValidationException(
                "Parameter Destination.Arn is not valid. "
                "Reason: Cross event bus replay is not permitted.")

        if start_time > end_time:
            raise ValidationException(
                "Parameter EventEndTime is not valid. "
                "Reason: EventStartTime must be before EventEndTime.")

        if name in self.replays:
            raise ResourceAlreadyExistsException(
                "Replay {} already exists.".format(name))

        replay = Replay(
            self.region_name,
            name,
            description,
            source_arn,
            start_time,
            end_time,
            destination,
        )

        self.replays[name] = replay

        replay.replay_events(archive)

        return {
            "ReplayArn": replay.arn,
            "ReplayStartTime": replay.start_time,
            "State": ReplayState.STARTING.
            value,  # the replay will be done before returning the response
        }

    def describe_replay(self, name):
        replay = self._get_replay(name)

        return replay.describe()

    def list_replays(self, name_prefix, source_arn, state):
        if [name_prefix, source_arn, state].count(None) < 2:
            raise ValidationException(
                "At most one filter is allowed for ListReplays. "
                "Use either : State, EventSourceArn, or NamePrefix.")

        valid_states = sorted([item.value for item in ReplayState])
        if state and state not in valid_states:
            raise ValidationException(
                "1 validation error detected: "
                "Value '{0}' at 'state' failed to satisfy constraint: "
                "Member must satisfy enum value set: "
                "[{1}]".format(state, ", ".join(valid_states)))

        if [name_prefix, source_arn, state].count(None) == 3:
            return [
                replay.describe_short() for replay in self.replays.values()
            ]

        result = []

        for replay in self.replays.values():
            if name_prefix and replay.name.startswith(name_prefix):
                result.append(replay.describe_short())
            elif source_arn and replay.source_arn == source_arn:
                result.append(replay.describe_short())
            elif state and replay.state == state:
                result.append(replay.describe_short())

        return result

    def cancel_replay(self, name):
        replay = self._get_replay(name)

        # Replays in the 'COMPLETED' state can't normally be cancelled, but
        # this implementation runs replays synchronously, so they are already
        # complete right after they start.
        if replay.state not in [
                ReplayState.STARTING,
                ReplayState.RUNNING,
                ReplayState.COMPLETED,
        ]:
            raise IllegalStatusException(
                "Replay {} is not in a valid state for this operation.".format(
                    name))

        replay.state = ReplayState.CANCELLED

        return {"ReplayArn": replay.arn, "State": ReplayState.CANCELLING.value}
Exemplo n.º 6
0
class EventsBackend(BaseBackend):
    ACCOUNT_ID = re.compile(r"^(\d{1,12}|\*)$")
    STATEMENT_ID = re.compile(r"^[a-zA-Z0-9-_]{1,64}$")

    def __init__(self, region_name):
        self.rules = {}
        # This list tracks the order in which the rules have been added,
        # since Python 2.6 doesn't have OrderedDict.
        self.rules_order = []
        self.next_tokens = {}
        self.region_name = region_name
        self.event_buses = {}
        self.event_sources = {}
        self.tagger = TaggingService()

        self._add_default_event_bus()

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def _add_default_event_bus(self):
        self.event_buses["default"] = EventBus(self.region_name, "default")

    def _get_rule_by_index(self, i):
        return self.rules.get(self.rules_order[i])

    def _gen_next_token(self, index):
        # bytes.encode("base64") was Python 2 only; use the base64 module
        # (requires "import base64") instead.
        token = base64.b64encode(os.urandom(128)).decode()
        self.next_tokens[token] = index
        return token

    def _process_token_and_limits(self,
                                  array_len,
                                  next_token=None,
                                  limit=None):
        start_index = 0
        end_index = array_len
        new_next_token = None

        if next_token:
            start_index = self.next_tokens.pop(next_token, 0)

        if limit is not None:
            new_end_index = start_index + int(limit)
            if new_end_index < end_index:
                end_index = new_end_index
                new_next_token = self._gen_next_token(end_index)

        return start_index, end_index, new_next_token
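        # Sketch: with 25 rules, no next_token and limit=10, this returns
        # (0, 10) and a fresh token mapped to index 10; passing that token
        # back on the next call resumes the listing at index 10.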

    def delete_rule(self, name):
        self.rules_order.remove(name)
        arn = self.rules.get(name).arn
        if self.tagger.has_tags(arn):
            self.tagger.delete_all_tags_for_resource(arn)
        return self.rules.pop(name) is not None

    def describe_rule(self, name):
        return self.rules.get(name)

    def disable_rule(self, name):
        if name in self.rules:
            self.rules[name].disable()
            return True

        return False

    def enable_rule(self, name):
        if name in self.rules:
            self.rules[name].enable()
            return True

        return False

    def list_rule_names_by_target(self,
                                  target_arn,
                                  next_token=None,
                                  limit=None):
        matching_rules = []
        return_obj = {}

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(self.rules), next_token, limit)

        for i in range(start_index, end_index):
            rule = self._get_rule_by_index(i)
            for target in rule.targets:
                if target["Arn"] == target_arn:
                    matching_rules.append(rule.name)

        return_obj["RuleNames"] = matching_rules
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def list_rules(self, prefix=None, next_token=None, limit=None):
        match_string = ".*"
        if prefix is not None:
            match_string = "^" + prefix + match_string

        match_regex = re.compile(match_string)

        matching_rules = []
        return_obj = {}

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(self.rules), next_token, limit)

        for i in range(start_index, end_index):
            rule = self._get_rule_by_index(i)
            if match_regex.match(rule.name):
                matching_rules.append(rule)

        return_obj["Rules"] = matching_rules
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def list_targets_by_rule(self, rule, next_token=None, limit=None):
        # Let a KeyError propagate if the rule doesn't exist; the response
        # layer is expected to handle it.
        rule = self.rules[rule]

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(rule.targets), next_token, limit)

        returned_targets = []
        return_obj = {}

        for i in range(start_index, end_index):
            returned_targets.append(rule.targets[i])

        return_obj["Targets"] = returned_targets
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def put_rule(self, name, **kwargs):
        new_rule = Rule(name, self.region_name, **kwargs)
        self.rules[new_rule.name] = new_rule
        self.rules_order.append(new_rule.name)
        return new_rule

    def put_targets(self, name, targets):
        rule = self.rules.get(name)

        if rule:
            rule.put_targets(targets)
            return True

        return False

    def put_events(self, events):
        num_events = len(events)

        if num_events < 1:
            raise JsonRESTError("ValidationError", "Need at least 1 event")
        elif num_events > 10:
            # The real error text is longer; AWS includes the full list of
            # submitted entries in the message.
            raise ValidationException(
                "1 validation error detected: "
                "Value '[PutEventsRequestEntry]' at 'entries' failed to satisfy constraint: "
                "Member must have length less than or equal to 10")

        entries = []
        for event in events:
            if "Source" not in event:
                entries.append({
                    "ErrorCode":
                    "InvalidArgument",
                    "ErrorMessage":
                    "Parameter Source is not valid. Reason: Source is a required argument.",
                })
            elif "DetailType" not in event:
                entries.append({
                    "ErrorCode":
                    "InvalidArgument",
                    "ErrorMessage":
                    "Parameter DetailType is not valid. Reason: DetailType is a required argument.",
                })
            elif "Detail" not in event:
                entries.append({
                    "ErrorCode":
                    "InvalidArgument",
                    "ErrorMessage":
                    "Parameter Detail is not valid. Reason: Detail is a required argument.",
                })
            else:
                try:
                    json.loads(event["Detail"])
                except ValueError:  # json.JSONDecodeError subclasses ValueError (Python 3.5+)
                    entries.append({
                        "ErrorCode": "MalformedDetail",
                        "ErrorMessage": "Detail is malformed.",
                    })
                    continue

                entries.append({"EventId": str(uuid4())})

        # We don't really need to store the events yet
        return entries

    def remove_targets(self, name, ids):
        rule = self.rules.get(name)

        if rule:
            rule.remove_targets(ids)
            return {"FailedEntries": [], "FailedEntryCount": 0}
        else:
            raise JsonRESTError(
                "ResourceNotFoundException",
                "An entity that you specified does not exist",
            )

    def test_event_pattern(self):
        raise NotImplementedError()

    def put_permission(self, event_bus_name, action, principal, statement_id):
        if not event_bus_name:
            event_bus_name = "default"

        event_bus = self.describe_event_bus(event_bus_name)

        if action is None or action != "events:PutEvents":
            raise JsonRESTError(
                "ValidationException",
                "Provided value in parameter 'action' is not supported.",
            )

        if principal is None or self.ACCOUNT_ID.match(principal) is None:
            raise JsonRESTError("InvalidParameterValue",
                                r"Principal must match ^(\d{1,12}|\*)$")

        if statement_id is None or self.STATEMENT_ID.match(
                statement_id) is None:
            raise JsonRESTError(
                "InvalidParameterValue",
                r"StatementId must match ^[a-zA-Z0-9-_]{1,64}$")

        event_bus._permissions[statement_id] = {
            "Action": action,
            "Principal": principal,
        }

    def remove_permission(self, event_bus_name, statement_id):
        if not event_bus_name:
            event_bus_name = "default"

        event_bus = self.describe_event_bus(event_bus_name)

        if not event_bus._permissions:
            raise JsonRESTError("ResourceNotFoundException",
                                "EventBus does not have a policy.")

        if not event_bus._permissions.pop(statement_id, None):
            raise JsonRESTError(
                "ResourceNotFoundException",
                "Statement with the provided id does not exist.",
            )

    def describe_event_bus(self, name):
        if not name:
            name = "default"

        event_bus = self.event_buses.get(name)

        if not event_bus:
            raise JsonRESTError("ResourceNotFoundException",
                                "Event bus {} does not exist.".format(name))

        return event_bus

    def create_event_bus(self, name, event_source_name=None):
        if name in self.event_buses:
            raise JsonRESTError(
                "ResourceAlreadyExistsException",
                "Event bus {} already exists.".format(name),
            )

        if not event_source_name and "/" in name:
            raise JsonRESTError("ValidationException",
                                "Event bus name must not contain '/'.")

        if event_source_name and event_source_name not in self.event_sources:
            raise JsonRESTError(
                "ResourceNotFoundException",
                "Event source {} does not exist.".format(event_source_name),
            )

        self.event_buses[name] = EventBus(self.region_name, name)

        return self.event_buses[name]

    def list_event_buses(self, name_prefix):
        if name_prefix:
            return [
                event_bus for event_bus in self.event_buses.values()
                if event_bus.name.startswith(name_prefix)
            ]

        return list(self.event_buses.values())

    def delete_event_bus(self, name):
        if name == "default":
            raise JsonRESTError("ValidationException",
                                "Cannot delete event bus default.")
        self.event_buses.pop(name, None)

    def list_tags_for_resource(self, arn):
        name = arn.split("/")[-1]
        if name in self.rules:
            return self.tagger.list_tags_for_resource(self.rules[name].arn)
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def tag_resource(self, arn, tags):
        name = arn.split("/")[-1]
        if name in self.rules:
            self.tagger.tag_resource(self.rules[name].arn, tags)
            return {}
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def untag_resource(self, arn, tag_names):
        name = arn.split("/")[-1]
        if name in self.rules:
            self.tagger.untag_resource_using_names(self.rules[name].arn,
                                                   tag_names)
            return {}
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))
Exemplo n.º 7
0
class Route53ResolverBackend(BaseBackend):
    """Implementation of Route53Resolver APIs."""

    def __init__(self, region_name, account_id):
        super().__init__(region_name, account_id)
        self.resolver_endpoints = {}  # Key is self-generated ID (endpoint_id)
        self.resolver_rules = {}  # Key is self-generated ID (rule_id)
        self.resolver_rule_associations = {}  # Key is self-generated ID (resolver_rule_association_id)
        self.tagger = TaggingService()

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """List of dicts representing default VPC endpoints for this service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "route53resolver"
        )

    def associate_resolver_rule(self, region, resolver_rule_id, name, vpc_id):
        validate_args(
            [("resolverRuleId", resolver_rule_id), ("name", name), ("vPCId", vpc_id)]
        )

        associations = [
            x for x in self.resolver_rule_associations.values() if x.region == region
        ]
        if len(associations) > ResolverRuleAssociation.MAX_RULE_ASSOCIATIONS_PER_REGION:
            # This error message was not verified to be the same for AWS.
            raise LimitExceededException(
                f"Account '{get_account_id()}' has exceeded 'max-rule-association'"
            )

        if resolver_rule_id not in self.resolver_rules:
            raise ResourceNotFoundException(
                f"Resolver rule with ID '{resolver_rule_id}' does not exist."
            )

        vpcs = ec2_backends[region].describe_vpcs()
        if vpc_id not in [x.id for x in vpcs]:
            raise InvalidParameterException(f"The vpc ID '{vpc_id}' does not exist")

        # Can't duplicate resolver rule, vpc id associations.
        for association in self.resolver_rule_associations.values():
            if (
                resolver_rule_id == association.resolver_rule_id
                and vpc_id == association.vpc_id
            ):
                raise InvalidRequestException(
                    f"Cannot associate rules with same domain name with same "
                    f"VPC. Conflict with resolver rule '{resolver_rule_id}'"
                )

        rule_association_id = f"rslvr-rrassoc-{get_random_hex(17)}"
        rule_association = ResolverRuleAssociation(
            region, rule_association_id, resolver_rule_id, vpc_id, name
        )
        self.resolver_rule_associations[rule_association_id] = rule_association
        return rule_association

    @staticmethod
    def _verify_subnet_ips(region, ip_addresses, initial=True):
        """Perform additional checks on the IPAddresses.

        NOTE: This does not include IPv6 addresses.
        """
        # Only the initial endpoint creation requires at least two IP addresses.
        if initial:
            if len(ip_addresses) < 2:
                raise InvalidRequestException(
                    "Resolver endpoint needs to have at least 2 IP addresses"
                )

        subnets = defaultdict(set)
        for subnet_id, ip_addr in [(x["SubnetId"], x["Ip"]) for x in ip_addresses]:
            try:
                subnet_info = ec2_backends[region].get_all_subnets(
                    subnet_ids=[subnet_id]
                )[0]
            except InvalidSubnetIdError as exc:
                raise InvalidParameterException(
                    f"The subnet ID '{subnet_id}' does not exist"
                ) from exc

            # IP in IPv4 CIDR range and not reserved?
            if ip_address(ip_addr) in subnet_info.reserved_ips or ip_address(
                ip_addr
            ) not in ip_network(subnet_info.cidr_block):
                raise InvalidRequestException(
                    f"IP address '{ip_addr}' is either not in subnet "
                    f"'{subnet_id}' CIDR range or is reserved"
                )

            if ip_addr in subnets[subnet_id]:
                raise ResourceExistsException(
                    f"The IP address '{ip_addr}' in subnet '{subnet_id}' is already in use"
                )
            subnets[subnet_id].add(ip_addr)
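        # Sketch: a second {"SubnetId": "subnet-1", "Ip": "10.0.0.5"} entry
        # would raise ResourceExistsException via the per-subnet sets built
        # above.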

    @staticmethod
    def _verify_security_group_ids(region, security_group_ids):
        """Perform additional checks on the security groups."""
        if len(security_group_ids) > 10:
            raise InvalidParameterException("Maximum of 10 security groups are allowed")

        for group_id in security_group_ids:
            if not group_id.startswith("sg-"):
                raise InvalidParameterException(
                    f"Malformed security group ID: Invalid id: '{group_id}' "
                    f"(expecting 'sg-...')"
                )
            try:
                ec2_backends[region].describe_security_groups(group_ids=[group_id])
            except InvalidSecurityGroupNotFoundError as exc:
                raise ResourceNotFoundException(
                    f"The security group '{group_id}' does not exist"
                ) from exc

    def create_resolver_endpoint(
        self,
        region,
        creator_request_id,
        name,
        security_group_ids,
        direction,
        ip_addresses,
        tags,
    ):  # pylint: disable=too-many-arguments
        """
        Return description for a newly created resolver endpoint.

        NOTE:  IPv6 IPs are currently not being filtered when
        calculating the create_resolver_endpoint() IpAddresses.
        """
        validate_args(
            [
                ("creatorRequestId", creator_request_id),
                ("direction", direction),
                ("ipAddresses", ip_addresses),
                ("name", name),
                ("securityGroupIds", security_group_ids),
                ("ipAddresses.subnetId", ip_addresses),
            ]
        )
        errmsg = self.tagger.validate_tags(
            tags or [], limit=ResolverEndpoint.MAX_TAGS_PER_RESOLVER_ENDPOINT
        )
        if errmsg:
            raise TagValidationException(errmsg)

        endpoints = [x for x in self.resolver_endpoints.values() if x.region == region]
        if len(endpoints) > ResolverEndpoint.MAX_ENDPOINTS_PER_REGION:
            raise LimitExceededException(
                f"Account '{get_account_id()}' has exceeded 'max-endpoints'"
            )

        for x in ip_addresses:
            if not x.get("Ip"):
                subnet_info = ec2_backends[region].get_all_subnets(
                    subnet_ids=[x["SubnetId"]]
                )[0]
                x["Ip"] = subnet_info.get_available_subnet_ip(self)

        self._verify_subnet_ips(region, ip_addresses)
        self._verify_security_group_ids(region, security_group_ids)
        if creator_request_id in [
            x.creator_request_id for x in self.resolver_endpoints.values()
        ]:
            raise ResourceExistsException(
                f"Resolver endpoint with creator request ID "
                f"'{creator_request_id}' already exists"
            )

        endpoint_id = (
            f"rslvr-{'in' if direction == 'INBOUND' else 'out'}-{get_random_hex(17)}"
        )
        resolver_endpoint = ResolverEndpoint(
            region,
            endpoint_id,
            creator_request_id,
            security_group_ids,
            direction,
            ip_addresses,
            name,
        )

        self.resolver_endpoints[endpoint_id] = resolver_endpoint
        self.tagger.tag_resource(resolver_endpoint.arn, tags or [])
        return resolver_endpoint

    def create_resolver_rule(
        self,
        region,
        creator_request_id,
        name,
        rule_type,
        domain_name,
        target_ips,
        resolver_endpoint_id,
        tags,
    ):  # pylint: disable=too-many-arguments
        """Return description for a newly created resolver rule."""
        validate_args(
            [
                ("creatorRequestId", creator_request_id),
                ("ruleType", rule_type),
                ("domainName", domain_name),
                ("name", name),
                *[("targetIps.port", x) for x in target_ips],
                ("resolverEndpointId", resolver_endpoint_id),
            ]
        )
        errmsg = self.tagger.validate_tags(
            tags or [], limit=ResolverRule.MAX_TAGS_PER_RESOLVER_RULE
        )
        if errmsg:
            raise TagValidationException(errmsg)

        rules = [x for x in self.resolver_rules.values() if x.region == region]
        if len(rules) > ResolverRule.MAX_RULES_PER_REGION:
            # Did not verify that this is the actual error message.
            raise LimitExceededException(
                f"Account '{get_account_id()}' has exceeded 'max-rules'"
            )

        # Per the AWS documentation and as seen with the AWS console, target
        # ips are only relevant when the value of Rule is FORWARD.  However,
        # boto3 ignores this condition and so shall we.

        for ip_addr in [x["Ip"] for x in target_ips]:
            try:
                # boto3 fails with an InternalServiceException if IPv6
                # addresses are used, which isn't helpful.
                if not isinstance(ip_address(ip_addr), IPv4Address):
                    raise InvalidParameterException(
                        f"Only IPv4 addresses may be used: '{ip_addr}'"
                    )
            except ValueError as exc:
                raise InvalidParameterException(
                    f"Invalid IP address: '{ip_addr}'"
                ) from exc

        # The boto3 documentation indicates that ResolverEndpoint is
        # optional, as does the AWS documentation.  But if resolver_endpoint_id
        # is None or an empty string, boto3 raises a ParamValidationError
        # about either the type or the length of the string.
        if resolver_endpoint_id:
            if resolver_endpoint_id not in [
                x.id for x in self.resolver_endpoints.values()
            ]:
                raise ResourceNotFoundException(
                    f"Resolver endpoint with ID '{resolver_endpoint_id}' does not exist."
                )

            if rule_type == "SYSTEM":
                raise InvalidRequestException(
                    "Cannot specify resolver endpoint ID and target IP "
                    "for SYSTEM type resolver rule"
                )

        if creator_request_id in [
            x.creator_request_id for x in self.resolver_rules.values()
        ]:
            raise ResourceExistsException(
                f"Resolver rule with creator request ID "
                f"'{creator_request_id}' already exists"
            )

        rule_id = f"rslvr-rr-{get_random_hex(17)}"
        resolver_rule = ResolverRule(
            region,
            rule_id,
            creator_request_id,
            rule_type,
            domain_name,
            target_ips,
            resolver_endpoint_id,
            name,
        )

        self.resolver_rules[rule_id] = resolver_rule
        self.tagger.tag_resource(resolver_rule.arn, tags or [])
        return resolver_rule

    def _validate_resolver_endpoint_id(self, resolver_endpoint_id):
        """Raise an exception if the id is invalid or unknown."""
        validate_args([("resolverEndpointId", resolver_endpoint_id)])
        if resolver_endpoint_id not in self.resolver_endpoints:
            raise ResourceNotFoundException(
                f"Resolver endpoint with ID '{resolver_endpoint_id}' does not exist"
            )

    def delete_resolver_endpoint(self, resolver_endpoint_id):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)

        # Can't delete an endpoint if there are rules associated with it.
        rules = [
            x.id
            for x in self.resolver_rules.values()
            if x.resolver_endpoint_id == resolver_endpoint_id
        ]
        if rules:
            raise InvalidRequestException(
                f"Cannot delete resolver endpoint unless its related resolver "
                f"rules are deleted.  The following rules still exist for "
                f"this resolver endpoint:  {','.join(rules)}"
            )

        # Tags are stored under the endpoint's ARN (see create_resolver_endpoint),
        # so delete them by ARN rather than by the endpoint ID.
        resolver_endpoint = self.resolver_endpoints.pop(resolver_endpoint_id)
        self.tagger.delete_all_tags_for_resource(resolver_endpoint.arn)
        resolver_endpoint.delete_eni()
        resolver_endpoint.status = "DELETING"
        resolver_endpoint.status_message = resolver_endpoint.status_message.replace(
            "Successfully created", "Deleting"
        )
        return resolver_endpoint

    def _validate_resolver_rule_id(self, resolver_rule_id):
        """Raise an exception if the id is invalid or unknown."""
        validate_args([("resolverRuleId", resolver_rule_id)])
        if resolver_rule_id not in self.resolver_rules:
            raise ResourceNotFoundException(
                f"Resolver rule with ID '{resolver_rule_id}' does not exist"
            )

    def delete_resolver_rule(self, resolver_rule_id):
        self._validate_resolver_rule_id(resolver_rule_id)

        # Can't delete a rule until its VPC associations are removed.
        associations = [
            x.id
            for x in self.resolver_rule_associations.values()
            if x.resolver_rule_id == resolver_rule_id
        ]
        if associations:
            raise ResourceInUseException(
                "Please disassociate this resolver rule from VPC first "
                "before deleting"
            )

        # As with endpoints, tags are keyed by the ARN, not the rule ID.
        resolver_rule = self.resolver_rules.pop(resolver_rule_id)
        self.tagger.delete_all_tags_for_resource(resolver_rule.arn)
        resolver_rule.status = "DELETING"
        resolver_rule.status_message = resolver_rule.status_message.replace(
            "Successfully created", "Deleting"
        )
        return resolver_rule

    def disassociate_resolver_rule(self, resolver_rule_id, vpc_id):
        validate_args([("resolverRuleId", resolver_rule_id), ("vPCId", vpc_id)])

        # Non-existent rule or vpc ids?
        if resolver_rule_id not in self.resolver_rules:
            raise ResourceNotFoundException(
                f"Resolver rule with ID '{resolver_rule_id}' does not exist"
            )

        # Find the matching association for this rule and vpc.
        rule_association_id = None
        for association in self.resolver_rule_associations.values():
            if (
                resolver_rule_id == association.resolver_rule_id
                and vpc_id == association.vpc_id
            ):
                rule_association_id = association.id
                break
        else:
            raise ResourceNotFoundException(
                f"Resolver Rule Association between Resolver Rule "
                f"'{resolver_rule_id}' and VPC '{vpc_id}' does not exist"
            )

        rule_association = self.resolver_rule_associations.pop(rule_association_id)
        rule_association.status = "DELETING"
        rule_association.status_message = "Deleting Association"
        return rule_association

    def get_resolver_endpoint(self, resolver_endpoint_id):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        return self.resolver_endpoints[resolver_endpoint_id]

    def get_resolver_rule(self, resolver_rule_id):
        """Return info for specified resolver rule."""
        self._validate_resolver_rule_id(resolver_rule_id)
        return self.resolver_rules[resolver_rule_id]

    def get_resolver_rule_association(self, resolver_rule_association_id):
        validate_args([("resolverRuleAssociationId", resolver_rule_association_id)])
        if resolver_rule_association_id not in self.resolver_rule_associations:
            raise ResourceNotFoundException(
                f"ResolverRuleAssociation '{resolver_rule_association_id}' does not Exist"
            )
        return self.resolver_rule_associations[resolver_rule_association_id]

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_endpoint_ip_addresses(self, resolver_endpoint_id):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        endpoint = self.resolver_endpoints[resolver_endpoint_id]
        return endpoint.ip_descriptions()

    @staticmethod
    def _add_field_name_to_filter(filters):
        """Convert both styles of filter names to lowercase snake format.

        "IP_ADDRESS_COUNT" or "IpAddressCount" will become "ip_address_count".
        However, "HostVPCId" doesn't fit the pattern, so that's treated
        special.
        """
        for rr_filter in filters:
            filter_name = rr_filter["Name"]
            if "_" not in filter_name:
                if "Vpc" in filter_name:
                    filter_name = "WRONG"
                elif filter_name == "HostVPCId":
                    filter_name = "host_vpc_id"
                elif filter_name == "VPCId":
                    filter_name = "vpc_id"
                elif filter_name in ["Type", "TYPE"]:
                    filter_name = "rule_type"
                elif not filter_name.isupper():
                    filter_name = CAMEL_TO_SNAKE_PATTERN.sub("_", filter_name)
            rr_filter["Field"] = filter_name.lower()

    @staticmethod
    def _validate_filters(filters, allowed_filter_names):
        """Raise exception if filter names are not as expected."""
        for rr_filter in filters:
            if rr_filter["Field"] not in allowed_filter_names:
                raise InvalidParameterException(
                    f"The filter '{rr_filter['Name']}' is invalid"
                )
            if "Values" not in rr_filter:
                raise InvalidParameterException(
                    f"No values specified for filter {rr_filter['Name']}"
                )

    @staticmethod
    def _matches_all_filters(entity, filters):
        """Return True if this entity has fields matching all the filters."""
        for rr_filter in filters:
            field_value = getattr(entity, rr_filter["Field"])

            if isinstance(field_value, list):
                if not set(field_value).intersection(rr_filter["Values"]):
                    return False
            elif isinstance(field_value, int):
                if str(field_value) not in rr_filter["Values"]:
                    return False
            elif field_value not in rr_filter["Values"]:
                return False

        return True
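        # Sketch: an endpoint whose ip_address_count is the int 2 matches
        # {"Field": "ip_address_count", "Values": ["2"]}, since ints are
        # compared via str(); list-valued fields match on any intersection.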

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_endpoints(self, filters):
        if not filters:
            filters = []

        self._add_field_name_to_filter(filters)
        self._validate_filters(filters, ResolverEndpoint.FILTER_NAMES)

        endpoints = []
        for endpoint in sorted(self.resolver_endpoints.values(), key=lambda x: x.name):
            if self._matches_all_filters(endpoint, filters):
                endpoints.append(endpoint)
        return endpoints

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_rules(self, filters):
        if not filters:
            filters = []

        self._add_field_name_to_filter(filters)
        self._validate_filters(filters, ResolverRule.FILTER_NAMES)

        rules = []
        for rule in sorted(self.resolver_rules.values(), key=lambda x: x.name):
            if self._matches_all_filters(rule, filters):
                rules.append(rule)
        return rules

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_rule_associations(self, filters):
        if not filters:
            filters = []

        self._add_field_name_to_filter(filters)
        self._validate_filters(filters, ResolverRuleAssociation.FILTER_NAMES)

        rules = []
        for rule in sorted(
            self.resolver_rule_associations.values(), key=lambda x: x.name
        ):
            if self._matches_all_filters(rule, filters):
                rules.append(rule)
        return rules

    def _matched_arn(self, resource_arn):
        """Given ARN, raise exception if there is no corresponding resource."""
        for resolver_endpoint in self.resolver_endpoints.values():
            if resolver_endpoint.arn == resource_arn:
                return
        for resolver_rule in self.resolver_rules.values():
            if resolver_rule.arn == resource_arn:
                return
        raise ResourceNotFoundException(
            f"Resolver endpoint with ID '{resource_arn}' does not exist"
        )

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags_for_resource(self, resource_arn):
        self._matched_arn(resource_arn)
        return self.tagger.list_tags_for_resource(resource_arn).get("Tags")

    def tag_resource(self, resource_arn, tags):
        self._matched_arn(resource_arn)
        errmsg = self.tagger.validate_tags(
            tags, limit=ResolverEndpoint.MAX_TAGS_PER_RESOLVER_ENDPOINT
        )
        if errmsg:
            raise TagValidationException(errmsg)
        self.tagger.tag_resource(resource_arn, tags)

    def untag_resource(self, resource_arn, tag_keys):
        self._matched_arn(resource_arn)
        self.tagger.untag_resource_using_names(resource_arn, tag_keys)

    def update_resolver_endpoint(self, resolver_endpoint_id, name):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        validate_args([("name", name)])
        resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]
        resolver_endpoint.update_name(name)
        return resolver_endpoint

    def associate_resolver_endpoint_ip_address(
        self, region, resolver_endpoint_id, ip_address
    ):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]

        if not ip_address.get("Ip"):
            subnet_info = ec2_backends[region].get_all_subnets(
                subnet_ids=[ip_address.get("SubnetId")]
            )[0]
            ip_address["Ip"] = subnet_info.get_available_subnet_ip(self)
        self._verify_subnet_ips(region, [ip_address], False)

        resolver_endpoint.associate_ip_address(ip_address)
        return resolver_endpoint

    def disassociate_resolver_endpoint_ip_address(
        self, resolver_endpoint_id, ip_address
    ):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]

        if not (ip_address.get("Ip") or ip_address.get("IpId")):
            raise InvalidRequestException(
                "[RSLVR-00503] Need to specify either the IP ID or both subnet and IP address in order to remove IP address."
            )

        resolver_endpoint.disassociate_ip_address(ip_address)
        return resolver_endpoint
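
To make the filter plumbing above concrete, a small sketch (not from the source) of _add_field_name_to_filter's normalization; it assumes CAMEL_TO_SNAKE_PATTERN inserts an underscore before each interior capital.

filters = [
    {"Name": "IpAddressCount", "Values": ["2"]},    # camel case
    {"Name": "IP_ADDRESS_COUNT", "Values": ["2"]},  # upper snake case
    {"Name": "HostVPCId", "Values": ["vpc-123"]},   # special-cased
]
Route53ResolverBackend._add_field_name_to_filter(filters)
assert [f["Field"] for f in filters] == [
    "ip_address_count", "ip_address_count", "host_vpc_id"
]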
Exemplo n.º 8
0
class FirehoseBackend(BaseBackend):
    """Implementation of Firehose APIs."""
    def __init__(self, region_name=None):
        self.region_name = region_name
        self.delivery_streams = {}
        self.tagger = TaggingService()

    def reset(self):
        """Re-initializes all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region,
            zones,
            "firehose",
            special_service_name="kinesis-firehose")

    def create_delivery_stream(
        self,
        region,
        delivery_stream_name,
        delivery_stream_type,
        kinesis_stream_source_configuration,
        delivery_stream_encryption_configuration_input,
        s3_destination_configuration,
        extended_s3_destination_configuration,
        redshift_destination_configuration,
        elasticsearch_destination_configuration,
        splunk_destination_configuration,
        http_endpoint_destination_configuration,
        tags,
    ):  # pylint: disable=too-many-arguments,too-many-locals,unused-argument
        """Create a Kinesis Data Firehose delivery stream."""
        (destination_name,
         destination_config) = find_destination_config_in_args(locals())

        if delivery_stream_name in self.delivery_streams:
            raise ResourceInUseException(
                f"Firehose {delivery_stream_name} under accountId {get_account_id()} "
                f"already exists")

        if len(self.delivery_streams) == DeliveryStream.MAX_STREAMS_PER_REGION:
            raise LimitExceededException(
                f"You have already consumed your firehose quota of "
                f"{DeliveryStream.MAX_STREAMS_PER_REGION} hoses. Firehose "
                f"names: {list(self.delivery_streams.keys())}")

        # Rule out situations that are not yet implemented.
        if delivery_stream_encryption_configuration_input:
            warnings.warn(
                "A delivery stream with server-side encryption enabled is not "
                "yet implemented")

        if destination_name == "Splunk":
            warnings.warn(
                "A Splunk destination delivery stream is not yet implemented")

        if (kinesis_stream_source_configuration
                and delivery_stream_type != "KinesisStreamAsSource"):
            raise InvalidArgumentException(
                "KinesisSourceStreamConfig is only applicable for "
                "KinesisStreamAsSource stream type")

        # Validate the tags before proceeding.
        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)

        if tags and len(tags) > MAX_TAGS_PER_DELIVERY_STREAM:
            raise ValidationException(
                f"1 validation error detected: Value '{tags}' at 'tags' "
                f"failed to satisify contstraint: Member must have length "
                f"less than or equal to {MAX_TAGS_PER_DELIVERY_STREAM}")

        # Create a DeliveryStream instance that will be stored and indexed
        # by delivery stream name.  This instance will update the state and
        # create the ARN.
        delivery_stream = DeliveryStream(
            region,
            delivery_stream_name,
            delivery_stream_type,
            kinesis_stream_source_configuration,
            destination_name,
            destination_config,
        )
        self.tagger.tag_resource(delivery_stream.delivery_stream_arn, tags or [])

        self.delivery_streams[delivery_stream_name] = delivery_stream
        return self.delivery_streams[delivery_stream_name].delivery_stream_arn

    def delete_delivery_stream(self,
                               delivery_stream_name,
                               allow_force_delete=False):  # pylint: disable=unused-argument
        """Delete a delivery stream and its data.

        AllowForceDelete option is ignored as we only superficially
        apply state.
        """
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        self.tagger.delete_all_tags_for_resource(
            delivery_stream.delivery_stream_arn)

        delivery_stream.delivery_stream_status = "DELETING"
        self.delivery_streams.pop(delivery_stream_name)

    def describe_delivery_stream(self, delivery_stream_name, limit,
                                 exclusive_start_destination_id):  # pylint: disable=unused-argument
        """Return description of specified delivery stream and its status.

        Note:  the 'limit' and 'exclusive_start_destination_id' parameters
        are not currently processed/implemented.
        """
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        result = {"DeliveryStreamDescription": {"HasMoreDestinations": False}}
        for attribute, attribute_value in vars(delivery_stream).items():
            if not attribute_value:
                continue

            # Convert from attribute's snake case to camel case for outgoing
            # JSON.
            name = "".join([x.capitalize() for x in attribute.split("_")])

            # Fooey ... always an exception to the rule:
            if name == "DeliveryStreamArn":
                name = "DeliveryStreamARN"

            if name != "Destinations":
                if name == "Source":
                    result["DeliveryStreamDescription"][name] = {
                        "KinesisStreamSourceDescription": attribute_value
                    }
                else:
                    result["DeliveryStreamDescription"][name] = attribute_value
                continue

            result["DeliveryStreamDescription"]["Destinations"] = []
            for destination in attribute_value:
                description = {}
                for key, value in destination.items():
                    if key == "destination_id":
                        description["DestinationId"] = value
                    else:
                        description[f"{key}DestinationDescription"] = value

                result["DeliveryStreamDescription"]["Destinations"].append(
                    description)

        return result

    def list_delivery_streams(self, limit, delivery_stream_type,
                              exclusive_start_delivery_stream_name):
        """Return list of delivery streams in alphabetic order of names."""
        result = {"DeliveryStreamNames": [], "HasMoreDeliveryStreams": False}
        if not self.delivery_streams:
            return result

        # If delivery_stream_type is specified, filter out any stream that's
        # not of that type.
        stream_list = self.delivery_streams.keys()
        if delivery_stream_type:
            stream_list = [
                x for x in stream_list
                if self.delivery_streams[x].delivery_stream_type ==
                delivery_stream_type
            ]

        # The list is sorted alphabetically, not alphanumerically.
        sorted_list = sorted(stream_list)

        # Determine the limit or number of names to return in the list.
        limit = limit or DeliveryStream.MAX_STREAMS_PER_REGION

        # If a starting delivery stream name is given, find the index into
        # the sorted list, then add one to get the name following it.  If the
        # exclusive_start_delivery_stream_name doesn't exist, it's ignored.
        start = 0
        if exclusive_start_delivery_stream_name:
            if self.delivery_streams.get(exclusive_start_delivery_stream_name):
                start = sorted_list.index(
                    exclusive_start_delivery_stream_name) + 1

        result["DeliveryStreamNames"] = sorted_list[start:start + limit]
        if len(sorted_list) > (start + limit):
            result["HasMoreDeliveryStreams"] = True
        return result
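        # Sketch: with streams ["a", "b", "c"], limit=1 and
        # exclusive_start_delivery_stream_name="a", the slice above returns
        # ["b"] and HasMoreDeliveryStreams is True because "c" is unread.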

    def list_tags_for_delivery_stream(self, delivery_stream_name,
                                      exclusive_start_tag_key, limit):
        """Return list of tags."""
        result = {"Tags": [], "HasMoreTags": False}
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        tags = self.tagger.list_tags_for_resource(
            delivery_stream.delivery_stream_arn)["Tags"]
        keys = self.tagger.extract_tag_names(tags)

        # If a starting tag is given and can be found, find the index into
        # tags, then add one to get the tag following it.
        start = 0
        if exclusive_start_tag_key:
            if exclusive_start_tag_key in keys:
                start = keys.index(exclusive_start_tag_key) + 1

        limit = limit or MAX_TAGS_PER_DELIVERY_STREAM
        result["Tags"] = tags[start:start + limit]
        if len(tags) > (start + limit):
            result["HasMoreTags"] = True
        return result

    def put_record(self, delivery_stream_name, record):
        """Write a single data record into a Kinesis Data firehose stream."""
        result = self.put_record_batch(delivery_stream_name, [record])
        return {
            "RecordId": result["RequestResponses"][0]["RecordId"],
            "Encrypted": False,
        }

    @staticmethod
    def put_http_records(http_destination, records):
        """Put records to a HTTP destination."""
        # Mostly copied from localstack
        url = http_destination["EndpointConfiguration"]["Url"]
        headers = {"Content-Type": "application/json"}
        record_to_send = {
            "requestId": str(uuid4()),
            "timestamp": int(time()),
            "records": [{
                "data": record["Data"]
            } for record in records],
        }
        try:
            requests.post(url, json=record_to_send, headers=headers)
        except Exception as exc:
            # This could be better ...
            raise RuntimeError(
                "Firehose PutRecord(Batch) to HTTP destination failed"
            ) from exc
        return [{"RecordId": str(uuid4())} for _ in range(len(records))]

    @staticmethod
    def _format_s3_object_path(delivery_stream_name, version_id, prefix):
        """Return a S3 object path in the expected format."""
        # Taken from LocalStack's firehose logic, with minor changes.
        # See https://docs.aws.amazon.com/firehose/latest/dev/basic-deliver.html#s3-object-name
        # Path prefix pattern: myApp/YYYY/MM/DD/HH/
        # Object name pattern:
        # DeliveryStreamName-DeliveryStreamVersion-YYYY-MM-DD-HH-MM-SS-RandomString
        prefix = f"{prefix}{'' if prefix.endswith('/') else '/'}"
        now = datetime.utcnow()
        return (f"{prefix}{now.strftime('%Y/%m/%d/%H')}/"
                f"{delivery_stream_name}-{version_id}-"
                f"{now.strftime('%Y-%m-%d-%H-%M-%S')}-{str(uuid4())}")

    def put_s3_records(self, delivery_stream_name, version_id, s3_destination,
                       records):
        """Put records to a ExtendedS3 or S3 destination."""
        # Taken from LocalStack's firehose logic, with minor changes.
        bucket_name = s3_destination["BucketARN"].split(":")[-1]
        prefix = s3_destination.get("Prefix", "")
        object_path = self._format_s3_object_path(delivery_stream_name,
                                                  version_id, prefix)

        batched_data = b"".join([b64decode(r["Data"]) for r in records])
        try:
            s3_backend.put_object(bucket_name, object_path, batched_data)
        except Exception as exc:
            # This could be better ...
            raise RuntimeError(
                "Firehose PutRecord(Batch to S3 destination failed") from exc
        return [{"RecordId": str(uuid4())} for _ in range(len(records))]

    def put_record_batch(self, delivery_stream_name, records):
        """Write multiple data records into a Kinesis Data firehose stream."""
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        request_responses = []
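        # A delivery stream normally has a single destination; if several
        # are present, only the responses from the last destination
        # processed are returned.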
        for destination in delivery_stream.destinations:
            if "ExtendedS3" in destination:
                # ExtendedS3 is handled like S3, but in the future this
                # will probably need to be revisited.  This destination
                # must be checked before S3, otherwise both destinations
                # would be processed instead of just ExtendedS3.
                request_responses = self.put_s3_records(
                    delivery_stream_name,
                    delivery_stream.version_id,
                    destination["ExtendedS3"],
                    records,
                )
            elif "S3" in destination:
                request_responses = self.put_s3_records(
                    delivery_stream_name,
                    delivery_stream.version_id,
                    destination["S3"],
                    records,
                )
            elif "HttpEndpoint" in destination:
                request_responses = self.put_http_records(
                    destination["HttpEndpoint"], records)
            elif "Elasticsearch" in destination or "Redshift" in destination:
                # These destinations aren't implemented, as the underlying
                # services aren't implemented, so ignore the data but
                # return a "proper" response.
                request_responses = [{
                    "RecordId": str(uuid4())
                } for _ in range(len(records))]

        return {
            "FailedPutCount": 0,
            "Encrypted": False,
            "RequestResponses": request_responses,
        }

    def tag_delivery_stream(self, delivery_stream_name, tags):
        """Add/update tags for specified delivery stream."""
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        if len(tags) > MAX_TAGS_PER_DELIVERY_STREAM:
            raise ValidationException(
                f"1 validation error detected: Value '{tags}' at 'tags' "
                f"failed to satisify contstraint: Member must have length "
                f"less than or equal to {MAX_TAGS_PER_DELIVERY_STREAM}")

        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)

        self.tagger.tag_resource(delivery_stream.delivery_stream_arn, tags)

    def untag_delivery_stream(self, delivery_stream_name, tag_keys):
        """Removes tags from specified delivery stream."""
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        # If a tag key doesn't exist for the stream, boto3 ignores it.
        self.tagger.untag_resource_using_names(
            delivery_stream.delivery_stream_arn, tag_keys)

    def update_destination(
        self,
        delivery_stream_name,
        current_delivery_stream_version_id,
        destination_id,
        s3_destination_update,
        extended_s3_destination_update,
        s3_backup_mode,
        redshift_destination_update,
        elasticsearch_destination_update,
        splunk_destination_update,
        http_endpoint_destination_update,
    ):  # pylint: disable=unused-argument,too-many-arguments,too-many-locals
        """Updates specified destination of specified delivery stream."""
        (destination_name,
         destination_config) = find_destination_config_in_args(locals())

        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under accountId "
                f"{get_account_id()} not found.")

        if destination_name == "Splunk":
            warnings.warn(
                "A Splunk destination delivery stream is not yet implemented")

        if delivery_stream.version_id != current_delivery_stream_version_id:
            raise ConcurrentModificationException(
                f"Cannot update firehose: {delivery_stream_name} since the "
                f"current version id: {delivery_stream.version_id} and "
                f"specified version id: {current_delivery_stream_version_id} "
                f"do not match")

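        # Find the destination whose ID matches; the for/else raises when
        # no destination matches.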
        destination = {}
        destination_idx = 0
        for destination in delivery_stream.destinations:
            if destination["destination_id"] == destination_id:
                break
            destination_idx += 1
        else:
            raise InvalidArgumentException(
                "Destination Id {destination_id} not found")

        # Switching between Amazon ES and other services is not supported.
        # For an Amazon ES destination, you can only update to another Amazon
        # ES destination.  The same applies to HTTP; Splunk was not tested.
        if (destination_name == "Elasticsearch" and "Elasticsearch"
                not in destination) or (destination_name == "HttpEndpoint"
                                        and "HttpEndpoint" not in destination):
            raise InvalidArgumentException(
                f"Changing the destination type to or from {destination_name} "
                f"is not supported at this time.")

        # If this is a different type of destination configuration,
        # the existing configuration is reset first.
        if destination_name in destination:
            delivery_stream.destinations[destination_idx][
                destination_name].update(destination_config)
        else:
            delivery_stream.destinations[destination_idx] = {
                "destination_id": destination_id,
                destination_name: destination_config,
            }

        # Once S3 is updated to an ExtendedS3 destination, both remain in
        # the destination.  That means when one is updated, the other needs
        # to be updated as well.  The problem is that they don't have the
        # same fields.
        if destination_name == "ExtendedS3":
            delivery_stream.destinations[destination_idx][
                "S3"] = create_s3_destination_config(destination_config)
        elif destination_name == "S3" and "ExtendedS3" in destination:
            destination["ExtendedS3"] = {
                k: v
                for k, v in destination["S3"].items()
                if k in destination["ExtendedS3"]
            }

        # Increment version number and update the timestamp.
        delivery_stream.version_id = str(
            int(current_delivery_stream_version_id) + 1)
        delivery_stream.last_update_timestamp = datetime.now(
            timezone.utc).isoformat()

        # Unimplemented: processing of the "S3BackupMode" parameter.  Per the
        # documentation:  "You can update a delivery stream to enable Amazon
        # S3 backup if it is disabled.  If backup is enabled, you can't update
        # the delivery stream to disable it."

    def lookup_name_from_arn(self, arn):
        """Given an ARN, return the associated delivery stream name."""
        return self.delivery_streams.get(arn.split("/")[-1])

    def send_log_event(
        self,
        delivery_stream_arn,
        filter_name,
        log_group_name,
        log_stream_name,
        log_events,
    ):  # pylint: disable=too-many-arguments
        """Send log events to an S3 bucket after encoding and gzipping them."""
        data = {
            "logEvents": log_events,
            "logGroup": log_group_name,
            "logStream": log_stream_name,
            "messageType": "DATA_MESSAGE",
            "owner": get_account_id(),
            "subscriptionFilters": [filter_name],
        }

        output = io.BytesIO()
        with GzipFile(fileobj=output, mode="w") as fhandle:
            fhandle.write(
                json.dumps(data, separators=(",", ":")).encode("utf-8"))
        gzipped_payload = b64encode(output.getvalue()).decode("utf-8")

        delivery_stream = self.lookup_name_from_arn(delivery_stream_arn)
        self.put_s3_records(
            delivery_stream.delivery_stream_name,
            delivery_stream.version_id,
            delivery_stream.destinations[0]["S3"],
            [{
                "Data": gzipped_payload
            }],
        )
Exemplo n.º 9
0
class DirectoryServiceBackend(BaseBackend):
    """Implementation of DirectoryService APIs."""

    def __init__(self, region_name=None):
        self.region_name = region_name
        self.directories = {}
        self.tagger = TaggingService()

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """List of dicts representing default VPC endpoints for this service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "ds"
        )

    @staticmethod
    def _validate_create_directory_args(
        name, passwd, size, vpc_settings, description, short_name,
    ):  # pylint: disable=too-many-arguments
        """Raise exception if create_directory() args don't meet constraints.

        The error messages are accumulated before the exception is raised.
        """
        error_tuples = []
        passwd_pattern = (
            r"(?=^.{8,64}$)((?=.*\d)(?=.*[A-Z])(?=.*[a-z])|"
            r"(?=.*\d)(?=.*[^A-Za-z0-9\s])(?=.*[a-z])|"
            r"(?=.*[^A-Za-z0-9\s])(?=.*[A-Z])(?=.*[a-z])|"
            r"(?=.*\d)(?=.*[A-Z])(?=.*[^A-Za-z0-9\s]))^.*"
        )
        if not re.match(passwd_pattern, passwd):
            # Can't have an odd number of backslashes in a literal.
            json_pattern = passwd_pattern.replace("\\", r"\\")
            error_tuples.append(
                (
                    "password",
                    passwd,
                    fr"satisfy regular expression pattern: {json_pattern}",
                )
            )

        if size.lower() not in ["small", "large"]:
            error_tuples.append(
                ("size", size, "satisfy enum value set: [Small, Large]")
            )

        name_pattern = r"^([a-zA-Z0-9]+[\\.-])+([a-zA-Z0-9])+$"
        if not re.match(name_pattern, name):
            error_tuples.append(
                ("name", name, fr"satisfy regular expression pattern: {name_pattern}")
            )

        subnet_id_pattern = r"^(subnet-[0-9a-f]{8}|subnet-[0-9a-f]{17})$"
        for subnet in vpc_settings["SubnetIds"]:
            if not re.match(subnet_id_pattern, subnet):
                error_tuples.append(
                    (
                        "vpcSettings.subnetIds",
                        subnet,
                        fr"satisfy regular expression pattern: {subnet_id_pattern}",
                    )
                )

        if description and len(description) > 128:
            error_tuples.append(
                ("description", description, "have length less than or equal to 128")
            )

        short_name_pattern = r'^[^\/:*?"<>|.]+[^\/:*?"<>|]*$'
        if short_name and not re.match(short_name_pattern, short_name):
            json_pattern = short_name_pattern.replace("\\", r"\\").replace('"', r"\"")
            error_tuples.append(
                (
                    "shortName",
                    short_name,
                    fr"satisfy regular expression pattern: {json_pattern}",
                )
            )

        if error_tuples:
            raise DsValidationException(error_tuples)

    @staticmethod
    def _validate_vpc_setting_values(region, vpc_settings):
        """Raise exception if vpc_settings are invalid.

        If settings are valid, add AvailabilityZones to vpc_settings.
        """
        if len(vpc_settings["SubnetIds"]) != 2:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones."
            )

        from moto.ec2 import ec2_backends  # pylint: disable=import-outside-toplevel

        # Subnet IDs are checked before the VPC ID.  The Subnet IDs must
        # be valid and in different availability zones.
        try:
            subnets = ec2_backends[region].get_all_subnets(
                subnet_ids=vpc_settings["SubnetIds"]
            )
        except InvalidSubnetIdError as exc:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones."
            ) from exc

        regions = [subnet.availability_zone for subnet in subnets]
        if regions[0] == regions[1]:
            raise ClientException(
                "Invalid subnet ID(s). The two subnets must be in "
                "different Availability Zones."
            )

        vpcs = ec2_backends[region].describe_vpcs()
        if vpc_settings["VpcId"] not in [x.id for x in vpcs]:
            raise ClientException("Invalid VPC ID.")

        vpc_settings["AvailabilityZones"] = regions

    def create_directory(
        self, region, name, short_name, password, description, size, vpc_settings, tags
    ):  # pylint: disable=too-many-arguments
        """Create a fake Simple Ad Directory."""
        if len(self.directories) > Directory.CLOUDONLY_DIRECTORIES_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CLOUDONLY_DIRECTORIES_LIMIT} directories may be created"
            )

        # botocore doesn't look for missing vpc_settings, but boto3 does.
        if not vpc_settings:
            raise InvalidParameterException("VpcSettings must be specified.")

        self._validate_create_directory_args(
            name, password, size, vpc_settings, description, short_name,
        )
        self._validate_vpc_setting_values(region, vpc_settings)

        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)

        if len(tags or []) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            name,
            password,
            size,
            vpc_settings,
            directory_type="SimpleAD",
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags or [])
        return directory.directory_id

    def _validate_directory_id(self, directory_id):
        """Raise an exception if the directory id is invalid or unknown."""
        # Validation of ID takes precedence over a check for its existence.
        id_pattern = r"^d-[0-9a-f]{10}$"
        if not re.match(id_pattern, directory_id):
            raise DsValidationException(
                [
                    (
                        "directoryId",
                        directory_id,
                        fr"satisfy regular expression pattern: {id_pattern}",
                    )
                ]
            )

        if directory_id not in self.directories:
            raise EntityDoesNotExistException(
                f"Directory {directory_id} does not exist"
            )

    def delete_directory(self, directory_id):
        """Delete directory with the matching ID."""
        self._validate_directory_id(directory_id)
        self.tagger.delete_all_tags_for_resource(directory_id)
        self.directories.pop(directory_id)
        return directory_id

    @paginate(pagination_model=PAGINATION_MODEL)
    def describe_directories(
        self, directory_ids=None, next_token=None, limit=0
    ):  # pylint: disable=unused-argument
        """Return info on all directories or directories with matching IDs."""
        for directory_id in directory_ids or self.directories:
            self._validate_directory_id(directory_id)

        directories = list(self.directories.values())
        if directory_ids:
            directories = [x for x in directories if x.directory_id in directory_ids]
        return sorted(directories, key=lambda x: x.launch_time)

    def get_directory_limits(self):
        """Return hard-coded limits for the directories."""
        counts = {"SimpleAD": 0, "MicrosoftAD": 0, "ConnectedAD": 0}
        for directory in self.directories.values():
            if directory.directory_type == "SimpleAD":
                counts["SimpleAD"] += 1
            elif directory.directory_type in ["MicrosoftAD", "SharedMicrosoftAD"]:
                counts["MicrosoftAD"] += 1
            elif directory.directory_type == "ADConnector":
                counts["ConnectedAD"] += 1

        return {
            "CloudOnlyDirectoriesLimit": Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyDirectoriesCurrentCount": counts["SimpleAD"],
            "CloudOnlyDirectoriesLimitReached": counts["SimpleAD"]
            == Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyMicrosoftADLimit": Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "CloudOnlyMicrosoftADCurrentCount": counts["MicrosoftAD"],
            "CloudOnlyMicrosoftADLimitReached": counts["MicrosoftAD"]
            == Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "ConnectedDirectoriesLimit": Directory.CONNECTED_DIRECTORIES_LIMIT,
            "ConnectedDirectoriesCurrentCount": counts["ConnectedAD"],
            "ConnectedDirectoriesLimitReached": counts["ConnectedAD"]
            == Directory.CONNECTED_DIRECTORIES_LIMIT,
        }

    def add_tags_to_resource(self, resource_id, tags):
        """Add or overwrite one or more tags for specified directory."""
        self._validate_directory_id(resource_id)
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise TagLimitExceededException("Tag limit exceeded")
        self.tagger.tag_resource(resource_id, tags)

    def remove_tags_from_resource(self, resource_id, tag_keys):
        """Removes tags from a directory."""
        self._validate_directory_id(resource_id)
        self.tagger.untag_resource_using_names(resource_id, tag_keys)

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags_for_resource(
        self, resource_id, next_token=None, limit=None,
    ):  # pylint: disable=unused-argument
        """List all tags on a directory."""
        self._validate_directory_id(resource_id)
        return self.tagger.list_tags_for_resource(resource_id).get("Tags")
Exemplo n.º 10
0
class KmsBackend(BaseBackend):
    def __init__(self, region_name, account_id=None):
        super().__init__(region_name=region_name, account_id=account_id)
        self.keys = {}
        self.key_to_aliases = defaultdict(set)
        self.tagger = TaggingService(key_name="TagKey", value_name="TagValue")

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "kms")

    def _generate_default_keys(self, alias_name):
        """Creates default kms keys"""
        if alias_name in RESERVED_ALIASES:
            key = self.create_key(
                None,
                "ENCRYPT_DECRYPT",
                "SYMMETRIC_DEFAULT",
                "Default key",
                None,
                self.region_name,
            )
            self.add_alias(key.id, alias_name)
            return key.id

    def create_key(self,
                   policy,
                   key_usage,
                   key_spec,
                   description,
                   tags,
                   region,
                   multi_region=False):
        key = Key(policy, key_usage, key_spec, description, region,
                  multi_region)
        self.keys[key.id] = key
        if tags is not None and len(tags) > 0:
            self.tag_resource(key.id, tags)
        return key

    # https://docs.aws.amazon.com/kms/latest/developerguide/multi-region-keys-overview.html#mrk-sync-properties
    # In AWS, replicas of a key share only some properties with the original
    # key.  Those properties are updated in all replicas automatically when
    # they change on the original key, and they cannot be changed on a
    # replica directly.
    #
    # In our implementation we just create a copy of all the properties once,
    # without any protection from change, as the exact implementation is
    # currently infeasible.
    def replicate_key(self, key_id, replica_region):
        # Using copy() instead of deepcopy(), as the latter results in exception:
        #    TypeError: cannot pickle '_cffi_backend.FFI' object
        # Since we only update top level properties, copy() should suffice.
        replica_key = copy(self.keys[key_id])
        replica_key.region = replica_region
        to_region_backend = kms_backends[replica_region]
        to_region_backend.keys[replica_key.id] = replica_key

    def update_key_description(self, key_id, description):
        key = self.keys[self.get_key_id(key_id)]
        key.description = description

    def delete_key(self, key_id):
        if key_id in self.keys:
            if key_id in self.key_to_aliases:
                self.key_to_aliases.pop(key_id)
            self.tagger.delete_all_tags_for_resource(key_id)

            return self.keys.pop(key_id)

    def describe_key(self, key_id) -> Key:
        # Allow a key to be described by alias, alias ARN, key ARN, or
        # key ID, not just by KeyId.
        key_id = self.get_key_id(key_id)
        if r"alias/" in str(key_id).lower():
            key_id = self.get_key_id_from_alias(key_id)
        return self.keys[self.get_key_id(key_id)]

    def list_keys(self):
        return self.keys.values()

    @staticmethod
    def get_key_id(key_id):
        # Allow use of ARN as well as pure KeyId
        if key_id.startswith("arn:") and ":key/" in key_id:
            return key_id.split(":key/")[1]

        return key_id

    @staticmethod
    def get_alias_name(alias_name):
        # Allow use of ARN as well as alias name
        if alias_name.startswith("arn:") and ":alias/" in alias_name:
            return "alias/" + alias_name.split(":alias/")[1]

        return alias_name

    def any_id_to_key_id(self, key_id):
        """Go from any valid key ID to the raw key ID.

        Acceptable inputs:
        - raw key ID
        - key ARN
        - alias name
        - alias ARN
        """
        key_id = self.get_alias_name(key_id)
        key_id = self.get_key_id(key_id)
        if key_id.startswith("alias/"):
            key_id = self.get_key_id_from_alias(key_id)
        return key_id

    def alias_exists(self, alias_name):
        for aliases in self.key_to_aliases.values():
            if alias_name in aliases:
                return True

        return False

    def add_alias(self, target_key_id, alias_name):
        self.key_to_aliases[target_key_id].add(alias_name)

    def delete_alias(self, alias_name):
        """Delete the alias."""
        for aliases in self.key_to_aliases.values():
            if alias_name in aliases:
                aliases.remove(alias_name)

    def get_all_aliases(self):
        return self.key_to_aliases

    def get_key_id_from_alias(self, alias_name):
        for key_id, aliases in dict(self.key_to_aliases).items():
            if alias_name in ",".join(aliases):
                return key_id
        if alias_name in RESERVED_ALIASES:
            key_id = self._generate_default_keys(alias_name)
            return key_id
        return None

    def enable_key_rotation(self, key_id):
        self.keys[self.get_key_id(key_id)].key_rotation_status = True

    def disable_key_rotation(self, key_id):
        self.keys[self.get_key_id(key_id)].key_rotation_status = False

    def get_key_rotation_status(self, key_id):
        return self.keys[self.get_key_id(key_id)].key_rotation_status

    def put_key_policy(self, key_id, policy):
        self.keys[self.get_key_id(key_id)].policy = policy

    def get_key_policy(self, key_id):
        return self.keys[self.get_key_id(key_id)].policy

    def disable_key(self, key_id):
        self.keys[key_id].enabled = False
        self.keys[key_id].key_state = "Disabled"

    def enable_key(self, key_id):
        self.keys[key_id].enabled = True
        self.keys[key_id].key_state = "Enabled"

    def cancel_key_deletion(self, key_id):
        self.keys[key_id].key_state = "Disabled"
        self.keys[key_id].deletion_date = None

    def schedule_key_deletion(self, key_id, pending_window_in_days):
        if 7 <= pending_window_in_days <= 30:
            self.keys[key_id].enabled = False
            self.keys[key_id].key_state = "PendingDeletion"
            self.keys[key_id].deletion_date = datetime.now() + timedelta(
                days=pending_window_in_days)
            return unix_time(self.keys[key_id].deletion_date)

    def encrypt(self, key_id, plaintext, encryption_context):
        key_id = self.any_id_to_key_id(key_id)

        ciphertext_blob = encrypt(
            master_keys=self.keys,
            key_id=key_id,
            plaintext=plaintext,
            encryption_context=encryption_context,
        )
        arn = self.keys[key_id].arn
        return ciphertext_blob, arn

    def decrypt(self, ciphertext_blob, encryption_context):
        plaintext, key_id = decrypt(
            master_keys=self.keys,
            ciphertext_blob=ciphertext_blob,
            encryption_context=encryption_context,
        )
        arn = self.keys[key_id].arn
        return plaintext, arn

    def re_encrypt(
        self,
        ciphertext_blob,
        source_encryption_context,
        destination_key_id,
        destination_encryption_context,
    ):
        destination_key_id = self.any_id_to_key_id(destination_key_id)

        plaintext, decrypting_arn = self.decrypt(
            ciphertext_blob=ciphertext_blob,
            encryption_context=source_encryption_context,
        )
        new_ciphertext_blob, encrypting_arn = self.encrypt(
            key_id=destination_key_id,
            plaintext=plaintext,
            encryption_context=destination_encryption_context,
        )
        return new_ciphertext_blob, decrypting_arn, encrypting_arn

    def generate_data_key(self, key_id, encryption_context, number_of_bytes,
                          key_spec):
        key_id = self.any_id_to_key_id(key_id)

        if key_spec:
            # Note: Actual validation of key_spec is done in kms.responses
            if key_spec == "AES_128":
                plaintext_len = 16
            else:
                plaintext_len = 32
        else:
            plaintext_len = number_of_bytes

        plaintext = os.urandom(plaintext_len)

        ciphertext_blob, arn = self.encrypt(
            key_id=key_id,
            plaintext=plaintext,
            encryption_context=encryption_context)

        return plaintext, ciphertext_blob, arn

    def list_resource_tags(self, key_id_or_arn):
        key_id = self.get_key_id(key_id_or_arn)
        if key_id in self.keys:
            return self.tagger.list_tags_for_resource(key_id)
        raise JsonRESTError(
            "NotFoundException",
            "The request was rejected because the specified entity or resource could not be found.",
        )

    def tag_resource(self, key_id_or_arn, tags):
        key_id = self.get_key_id(key_id_or_arn)
        if key_id in self.keys:
            self.tagger.tag_resource(key_id, tags)
            return {}
        raise JsonRESTError(
            "NotFoundException",
            "The request was rejected because the specified entity or resource could not be found.",
        )

    def untag_resource(self, key_id_or_arn, tag_names):
        key_id = self.get_key_id(key_id_or_arn)
        if key_id in self.keys:
            self.tagger.untag_resource_using_names(key_id, tag_names)
            return {}
        raise JsonRESTError(
            "NotFoundException",
            "The request was rejected because the specified entity or resource could not be found.",
        )

    def create_grant(
        self,
        key_id,
        grantee_principal,
        operations,
        name,
        constraints,
        retiring_principal,
    ):
        key = self.describe_key(key_id)
        grant = key.add_grant(
            name,
            grantee_principal,
            operations,
            constraints=constraints,
            retiring_principal=retiring_principal,
        )
        return grant.id, grant.token

    def list_grants(self, key_id, grant_id) -> List[Grant]:
        key = self.describe_key(key_id)
        return key.list_grants(grant_id)

    def list_retirable_grants(self, retiring_principal):
        grants = []
        for key in self.keys.values():
            grants.extend(key.list_retirable_grants(retiring_principal))
        return grants

    def revoke_grant(self, key_id, grant_id) -> None:
        key = self.describe_key(key_id)
        key.revoke_grant(grant_id)

    def retire_grant(self, key_id, grant_id, grant_token) -> None:
        if grant_token:
            for key in self.keys.values():
                key.retire_grant_by_token(grant_token)
        else:
            key = self.describe_key(key_id)
            key.retire_grant(grant_id)

    def __ensure_valid_sign_and_verify_key(self, key: Key):
        if key.key_usage != "SIGN_VERIFY":
            raise ValidationException((
                "1 validation error detected: Value '{key_id}' at 'KeyId' failed "
                "to satisfy constraint: Member must point to a key with usage: 'SIGN_VERIFY'"
            ).format(key_id=key.id))

    def __ensure_valid_signing_algorithm(self, key: Key, signing_algorithm):
        if signing_algorithm not in key.signing_algorithms:
            raise ValidationException((
                "1 validation error detected: Value '{signing_algorithm}' at 'SigningAlgorithm' failed "
                "to satisfy constraint: Member must satisfy enum value set: "
                "{valid_sign_algorithms}").format(
                    signing_algorithm=signing_algorithm,
                    valid_sign_algorithms=key.signing_algorithms,
                ))

    def sign(self, key_id, message, signing_algorithm):
        """Sign message using generated private key.

        - signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256

        - grant_tokens are not implemented
        """
        key = self.describe_key(key_id)

        self.__ensure_valid_sign_and_verify_key(key)
        self.__ensure_valid_signing_algorithm(key, signing_algorithm)

        # TODO: support more than one hardcoded algorithm based on KeySpec
        signature = key.private_key.sign(
            message,
            padding.PSS(mgf=padding.MGF1(hashes.SHA256()),
                        salt_length=padding.PSS.MAX_LENGTH),
            hashes.SHA256(),
        )

        return key.arn, signature, signing_algorithm

    def verify(self, key_id, message, signature, signing_algorithm):
        """Verify message using public key from generated private key.

        - signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256

        - grant_tokens are not implemented
        """
        key = self.describe_key(key_id)

        self.__ensure_valid_sign_and_verify_key(key)
        self.__ensure_valid_signing_algorithm(key, signing_algorithm)

        public_key = key.private_key.public_key()

        try:
            # TODO: support more than one hardcoded algorithm based on KeySpec
            public_key.verify(
                signature,
                message,
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH,
                ),
                hashes.SHA256(),
            )
            return key.arn, True, signing_algorithm
        except InvalidSignature:
            return key.arn, False, signing_algorithm
Exemplo n.º 11
0
class ECRBackend(BaseBackend):
    def __init__(self, region_name):
        self.region_name = region_name
        self.repositories: Dict[str, Repository] = {}
        self.tagger = TaggingService(tagName="tags")

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def _get_repository(self, name, registry_id=None) -> Repository:
        repo = self.repositories.get(name)
        reg_id = registry_id or DEFAULT_REGISTRY_ID

        if not repo or repo.registry_id != reg_id:
            raise RepositoryNotFoundException(name, reg_id)
        return repo

    @staticmethod
    def _parse_resource_arn(resource_arn) -> EcrRepositoryArn:
        match = re.match(ECR_REPOSITORY_ARN_PATTERN, resource_arn)
        if not match:
            raise InvalidParameterException(
                "Invalid parameter at 'resourceArn' failed to satisfy constraint: "
                "'Invalid ARN'")
        return EcrRepositoryArn(**match.groupdict())

    def describe_repositories(self, registry_id=None, repository_names=None):
        """
        maxResults and nextToken not implemented
        """
        if repository_names:
            for repository_name in repository_names:
                if repository_name not in self.repositories:
                    raise RepositoryNotFoundException(
                        repository_name, registry_id or DEFAULT_REGISTRY_ID)

        repositories = []
        for repository in self.repositories.values():
            # If a registry_id was supplied, ensure this repository matches
            if registry_id:
                if repository.registry_id != registry_id:
                    continue
            # If a list of repository names was supplied, ensure this
            # repository is in that list.
            if repository_names:
                if repository.name not in repository_names:
                    continue
            repositories.append(repository.response_object)
        return repositories

    def create_repository(
        self,
        repository_name,
        encryption_config,
        image_scan_config,
        image_tag_mutablility,
        tags,
    ):
        if self.repositories.get(repository_name):
            raise RepositoryAlreadyExistsException(repository_name,
                                                   DEFAULT_REGISTRY_ID)

        repository = Repository(
            region_name=self.region_name,
            repository_name=repository_name,
            encryption_config=encryption_config,
            image_scan_config=image_scan_config,
            image_tag_mutablility=image_tag_mutablility,
        )
        self.repositories[repository_name] = repository
        self.tagger.tag_resource(repository.arn, tags)

        return repository

    def delete_repository(self, repository_name, registry_id=None):
        repo = self._get_repository(repository_name, registry_id)

        if repo.images:
            raise RepositoryNotEmptyException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        self.tagger.delete_all_tags_for_resource(repo.arn)
        return self.repositories.pop(repository_name)

    def list_images(self, repository_name, registry_id=None):
        """
        maxResults and filtering not implemented
        """
        repository = None
        found = False
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
            if registry_id:
                if repository.registry_id == registry_id:
                    found = True
            else:
                found = True

        if not found:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        images = []
        for image in repository.images:
            images.append(image)
        return images

    def describe_images(self,
                        repository_name,
                        registry_id=None,
                        image_ids=None):

        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        if image_ids:
            response = set()
            for image_id in image_ids:
                found = False
                for image in repository.images:
                    if ("imageDigest" in image_id and image.get_image_digest()
                            == image_id["imageDigest"]) or (
                                "imageTag" in image_id
                                and image_id["imageTag"] in image.image_tags):
                        found = True
                        response.add(image)
                if not found:
                    image_id_representation = "{imageDigest:'%s', imageTag:'%s'}" % (
                        image_id.get("imageDigest", "null"),
                        image_id.get("imageTag", "null"),
                    )
                    raise ImageNotFoundException(
                        image_id=image_id_representation,
                        repository_name=repository_name,
                        registry_id=registry_id or DEFAULT_REGISTRY_ID,
                    )

        else:
            response = []
            for image in repository.images:
                response.append(image)

        return response

    def put_image(self, repository_name, image_manifest, image_tag):
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise Exception("{0} is not a repository".format(repository_name))

        existing_images = list(
            filter(
                lambda x: x.response_object["imageManifest"] == image_manifest,
                repository.images,
            ))
        if not existing_images:
            # this image is not in ECR yet
            image = Image(image_tag, image_manifest, repository_name)
            repository.images.append(image)
            return image
        else:
            # update existing image
            existing_images[0].update_tag(image_tag)
            return existing_images[0]

    def batch_get_image(
        self,
        repository_name,
        registry_id=None,
        image_ids=None,
        accepted_media_types=None,
    ):
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        if not image_ids:
            raise ParamValidationError(
                msg='Missing required parameter in input: "imageIds"')

        response = {"images": [], "failures": []}

        for image_id in image_ids:
            found = False
            for image in repository.images:
                if ("imageDigest" in image_id
                        and image.get_image_digest() == image_id["imageDigest"]
                    ) or ("imageTag" in image_id
                          and image.image_tag == image_id["imageTag"]):
                    found = True
                    response["images"].append(image.response_batch_get_image)

            # Check inside the loop so a failure is recorded for every
            # missing image, not just the last one requested.
            if not found:
                response["failures"].append({
                    "imageId": {
                        "imageTag": image_id.get("imageTag", "null")
                    },
                    "failureCode": "ImageNotFound",
                    "failureReason": "Requested image not found",
                })

        return response

    def batch_delete_image(self,
                           repository_name,
                           registry_id=None,
                           image_ids=None):
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        if not image_ids:
            raise ParamValidationError(
                msg='Missing required parameter in input: "imageIds"')

        response = {"imageIds": [], "failures": []}

        for image_id in image_ids:
            image_found = False

            # Is request missing both digest and tag?
            if "imageDigest" not in image_id and "imageTag" not in image_id:
                response["failures"].append({
                    "imageId": {},
                    "failureCode":
                    "MissingDigestAndTag",
                    "failureReason":
                    "Invalid request parameters: both tag and digest cannot be null",
                })
                continue

            # If we have a digest, is it valid?
            if "imageDigest" in image_id:
                pattern = re.compile(r"^[0-9a-zA-Z_+\.-]+:[0-9a-fA-F]{64}")
                if not pattern.match(image_id.get("imageDigest")):
                    response["failures"].append({
                        "imageId": {
                            "imageDigest": image_id.get("imageDigest", "null")
                        },
                        "failureCode":
                        "InvalidImageDigest",
                        "failureReason":
                        "Invalid request parameters: image digest should satisfy the regex '[a-zA-Z0-9-_+.]+:[a-fA-F0-9]+'",
                    })
                    continue

            for num, image in enumerate(repository.images):

                # Search by matching both digest and tag
                if "imageDigest" in image_id and "imageTag" in image_id:
                    if (image_id["imageDigest"] == image.get_image_digest()
                            and image_id["imageTag"] in image.image_tags):
                        image_found = True
                        for image_tag in reversed(image.image_tags):
                            repository.images[num].image_tag = image_tag
                            response["imageIds"].append(
                                image.response_batch_delete_image)
                            repository.images[num].remove_tag(image_tag)
                        del repository.images[num]

                # Search by matching digest
                elif ("imageDigest" in image_id
                      and image.get_image_digest() == image_id["imageDigest"]):
                    image_found = True
                    for image_tag in reversed(image.image_tags):
                        repository.images[num].image_tag = image_tag
                        response["imageIds"].append(
                            image.response_batch_delete_image)
                        repository.images[num].remove_tag(image_tag)
                    del repository.images[num]

                # Search by matching tag
                elif ("imageTag" in image_id
                      and image_id["imageTag"] in image.image_tags):
                    image_found = True
                    repository.images[num].image_tag = image_id["imageTag"]
                    response["imageIds"].append(
                        image.response_batch_delete_image)
                    if len(image.image_tags) > 1:
                        repository.images[num].remove_tag(image_id["imageTag"])
                    else:
                        repository.images.remove(image)

            # Evaluated after scanning all images; appending inside the
            # image loop would record a failure for every non-matching image.
            if not image_found:
                failure_response = {
                    "imageId": {},
                    "failureCode": "ImageNotFound",
                    "failureReason": "Requested image not found",
                }

                if "imageDigest" in image_id:
                    failure_response["imageId"]["imageDigest"] = image_id.get(
                        "imageDigest", "null")

                if "imageTag" in image_id:
                    failure_response["imageId"]["imageTag"] = image_id.get(
                        "imageTag", "null")

                response["failures"].append(failure_response)

        return response

    def list_tags_for_resource(self, arn):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)

        return self.tagger.list_tags_for_resource(repo.arn)

    def tag_resource(self, arn, tags):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)
        self.tagger.tag_resource(repo.arn, tags)

        return {}

    def untag_resource(self, arn, tag_keys):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)
        self.tagger.untag_resource_using_names(repo.arn, tag_keys)

        return {}

    def put_image_tag_mutability(self, registry_id, repository_name,
                                 image_tag_mutability):
        if image_tag_mutability not in ["IMMUTABLE", "MUTABLE"]:
            raise InvalidParameterException(
                "Invalid parameter at 'imageTagMutability' failed to satisfy constraint: "
                "'Member must satisfy enum value set: [IMMUTABLE, MUTABLE]'")

        repo = self._get_repository(repository_name, registry_id)
        repo.update(image_tag_mutability=image_tag_mutability)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageTagMutability": repo.image_tag_mutability,
        }

    def put_image_scanning_configuration(self, registry_id, repository_name,
                                         image_scan_config):
        repo = self._get_repository(repository_name, registry_id)
        repo.update(image_scan_config=image_scan_config)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageScanningConfiguration": repo.image_scanning_configuration,
        }

    def set_repository_policy(self, registry_id, repository_name, policy_text):
        repo = self._get_repository(repository_name, registry_id)

        try:
            iam_policy_document_validator = IAMPolicyDocumentValidator(
                policy_text)
            # the repository policy can be defined without a resource field
            iam_policy_document_validator._validate_resource_exist = lambda: None
            # the repository policy can have the old version 2008-10-17
            iam_policy_document_validator._validate_version = lambda: None
            iam_policy_document_validator.validate()
        except MalformedPolicyDocument:
            raise InvalidParameterException(
                "Invalid parameter at 'PolicyText' failed to satisfy constraint: "
                "'Invalid repository policy provided'")

        repo.policy = policy_text

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": repo.policy,
        }

    def get_repository_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)

        if not repo.policy:
            raise RepositoryPolicyNotFoundException(repository_name,
                                                    repo.registry_id)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": repo.policy,
        }

    def delete_repository_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)
        policy = repo.policy

        if not policy:
            raise RepositoryPolicyNotFoundException(repository_name,
                                                    repo.registry_id)

        repo.policy = None

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": policy,
        }

    def put_lifecycle_policy(self, registry_id, repository_name,
                             lifecycle_policy_text):
        repo = self._get_repository(repository_name, registry_id)

        validator = EcrLifecyclePolicyValidator(lifecycle_policy_text)
        validator.validate()

        repo.lifecycle_policy = lifecycle_policy_text

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "lifecyclePolicyText": repo.lifecycle_policy,
        }

    def get_lifecycle_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)

        if not repo.lifecycle_policy:
            raise LifecyclePolicyNotFoundException(repository_name,
                                                   repo.registry_id)

        return {
            "registryId":
            repo.registry_id,
            "repositoryName":
            repository_name,
            "lifecyclePolicyText":
            repo.lifecycle_policy,
            "lastEvaluatedAt":
            iso_8601_datetime_without_milliseconds(datetime.utcnow()),
        }

    def delete_lifecycle_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)
        policy = repo.lifecycle_policy

        if not policy:
            raise LifecyclePolicyNotFoundException(repository_name,
                                                   repo.registry_id)

        repo.lifecycle_policy = None

        return {
            "registryId":
            repo.registry_id,
            "repositoryName":
            repository_name,
            "lifecyclePolicyText":
            policy,
            "lastEvaluatedAt":
            iso_8601_datetime_without_milliseconds(datetime.utcnow()),
        }
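A minimal sketch of put_image() and batch_delete_image(), assuming moto's mock_ecr decorator and boto3; the repository name and manifest contents are illustrative:

import json

import boto3
from moto import mock_ecr


@mock_ecr
def test_put_and_batch_delete_image():
    client = boto3.client("ecr", region_name="us-east-1")
    client.create_repository(repositoryName="test-repo")
    manifest = json.dumps({"schemaVersion": 2, "layers": []})
    client.put_image(repositoryName="test-repo",
                     imageManifest=manifest,
                     imageTag="v1")
    # Deleting by tag removes the image entirely when it carries no
    # other tags.
    response = client.batch_delete_image(repositoryName="test-repo",
                                         imageIds=[{"imageTag": "v1"}])
    assert response["imageIds"][0]["imageTag"] == "v1"
    assert response["failures"] == []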