Пример #1
0
class PinpointBackend(BaseBackend):
    """Implementation of Pinpoint APIs."""
    def __init__(self, region_name, account_id):
        super().__init__(region_name, account_id)
        self.apps = {}
        self.tagger = TaggingService()

    def create_app(self, name, tags):
        app = App(name)
        self.apps[app.application_id] = app
        tags = self.tagger.convert_dict_to_tags_input(tags)
        self.tagger.tag_resource(app.arn, tags)
        return app

    def delete_app(self, application_id):
        self.get_app(application_id)
        return self.apps.pop(application_id)

    def get_app(self, application_id):
        if application_id not in self.apps:
            raise ApplicationNotFound()
        return self.apps[application_id]

    def get_apps(self):
        """
        Pagination is not yet implemented
        """
        return self.apps.values()

    def update_application_settings(self, application_id, settings):
        app = self.get_app(application_id)
        return app.update_settings(settings)

    def get_application_settings(self, application_id):
        app = self.get_app(application_id)
        return app.get_settings()

    def list_tags_for_resource(self, resource_arn):
        tags = self.tagger.get_tag_dict_for_resource(resource_arn)
        return {"tags": tags}

    def tag_resource(self, resource_arn, tags):
        tags = TaggingService.convert_dict_to_tags_input(tags)
        self.tagger.tag_resource(resource_arn, tags)

    def untag_resource(self, resource_arn, tag_keys):
        self.tagger.untag_resource_using_names(resource_arn, tag_keys)
        return

    def put_event_stream(self, application_id, stream_arn, role_arn):
        app = self.get_app(application_id)
        return app.put_event_stream(stream_arn, role_arn)

    def get_event_stream(self, application_id):
        app = self.get_app(application_id)
        return app.get_event_stream()

    def delete_event_stream(self, application_id):
        app = self.get_app(application_id)
        return app.delete_event_stream()
Пример #2
0
def test_create_tag_without_value():
    svc = TaggingService()
    tags = [{"Key": "key_key"}]
    svc.tag_resource("arn", tags)
    actual = svc.list_tags_for_resource("arn")
    expected = {"Tags": [{"Key": "key_key", "Value": None}]}

    expected.should.be.equal(actual)
Пример #3
0
def test_create_tag():
    svc = TaggingService("TheTags", "TagKey", "TagValue")
    tags = [{"TagKey": "key_key", "TagValue": "value_value"}]
    svc.tag_resource("arn", tags)
    actual = svc.list_tags_for_resource("arn")
    expected = {"TheTags": [{"TagKey": "key_key", "TagValue": "value_value"}]}

    expected.should.be.equal(actual)
Пример #4
0
def test_delete_tag_using_tags():
    svc = TaggingService()
    tags = [{"Key": "key_key", "Value": "value_value"}]
    svc.tag_resource("arn", tags)
    svc.untag_resource_using_tags("arn", tags)
    result = svc.list_tags_for_resource("arn")

    {"Tags": []}.should.be.equal(result)
Пример #5
0
def test_delete_all_tags_for_resource():
    svc = TaggingService()
    tags = [{"Key": "key_key", "Value": "value_value"}]
    tags2 = [{"Key": "key_key2", "Value": "value_value2"}]
    svc.tag_resource("arn", tags)
    svc.tag_resource("arn", tags2)
    svc.delete_all_tags_for_resource("arn")
    result = svc.list_tags_for_resource("arn")

    {"Tags": []}.should.be.equal(result)
Пример #6
0
def test_copy_non_existing_arn():
    svc = TaggingService()
    tags = [{
        "Key": "key1",
        "Value": "value1"
    }, {
        "Key": "key2",
        "Value": "value2"
    }]
    svc.tag_resource("new_arn", tags)
    #
    svc.copy_tags("non_existing_arn", "new_arn")
    # Copying from a non-existing ARN should a NOOP
    # Assert the old tags still exist
    actual = sorted(svc.list_tags_for_resource("new_arn")["Tags"],
                    key=lambda t: t["Key"])
    actual.should.equal(tags)
Пример #7
0
def test_copy_existing_arn():
    svc = TaggingService()
    tags_old_arn = [{"Key": "key1", "Value": "value1"}]
    tags_new_arn = [{"Key": "key2", "Value": "value2"}]
    svc.tag_resource("old_arn", tags_old_arn)
    svc.tag_resource("new_arn", tags_new_arn)
    #
    svc.copy_tags("old_arn", "new_arn")
    # Assert the old tags still exist
    actual = sorted(svc.list_tags_for_resource("new_arn")["Tags"],
                    key=lambda t: t["Key"])
    actual.should.equal([{
        "Key": "key1",
        "Value": "value1"
    }, {
        "Key": "key2",
        "Value": "value2"
    }])
Пример #8
0
class FirehoseBackend(BaseBackend):
    """Implementation of Firehose APIs."""
    def __init__(self, region_name=None):
        self.region_name = region_name
        self.delivery_streams = {}
        self.tagger = TaggingService()

    def reset(self):
        """Re-initializes all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region,
            zones,
            "firehose",
            special_service_name="kinesis-firehose")

    def create_delivery_stream(
        self,
        region,
        delivery_stream_name,
        delivery_stream_type,
        kinesis_stream_source_configuration,
        delivery_stream_encryption_configuration_input,
        s3_destination_configuration,
        extended_s3_destination_configuration,
        redshift_destination_configuration,
        elasticsearch_destination_configuration,
        splunk_destination_configuration,
        http_endpoint_destination_configuration,
        tags,
    ):  # pylint: disable=too-many-arguments,too-many-locals,unused-argument
        """Create a Kinesis Data Firehose delivery stream."""
        (destination_name,
         destination_config) = find_destination_config_in_args(locals())

        if delivery_stream_name in self.delivery_streams:
            raise ResourceInUseException(
                f"Firehose {delivery_stream_name} under accountId {get_account_id()} "
                f"already exists")

        if len(self.delivery_streams) == DeliveryStream.MAX_STREAMS_PER_REGION:
            raise LimitExceededException(
                f"You have already consumed your firehose quota of "
                f"{DeliveryStream.MAX_STREAMS_PER_REGION} hoses. Firehose "
                f"names: {list(self.delivery_streams.keys())}")

        # Rule out situations that are not yet implemented.
        if delivery_stream_encryption_configuration_input:
            warnings.warn(
                "A delivery stream with server-side encryption enabled is not "
                "yet implemented")

        if destination_name == "Splunk":
            warnings.warn(
                "A Splunk destination delivery stream is not yet implemented")

        if (kinesis_stream_source_configuration
                and delivery_stream_type != "KinesisStreamAsSource"):
            raise InvalidArgumentException(
                "KinesisSourceStreamConfig is only applicable for "
                "KinesisStreamAsSource stream type")

        # Validate the tags before proceeding.
        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)

        if tags and len(tags) > MAX_TAGS_PER_DELIVERY_STREAM:
            raise ValidationException(
                f"1 validation error detected: Value '{tags}' at 'tags' "
                f"failed to satisify contstraint: Member must have length "
                f"less than or equal to {MAX_TAGS_PER_DELIVERY_STREAM}")

        # Create a DeliveryStream instance that will be stored and indexed
        # by delivery stream name.  This instance will update the state and
        # create the ARN.
        delivery_stream = DeliveryStream(
            region,
            delivery_stream_name,
            delivery_stream_type,
            kinesis_stream_source_configuration,
            destination_name,
            destination_config,
        )
        self.tagger.tag_resource(delivery_stream.delivery_stream_arn, tags
                                 or [])

        self.delivery_streams[delivery_stream_name] = delivery_stream
        return self.delivery_streams[delivery_stream_name].delivery_stream_arn

    def delete_delivery_stream(self,
                               delivery_stream_name,
                               allow_force_delete=False):  # pylint: disable=unused-argument
        """Delete a delivery stream and its data.

        AllowForceDelete option is ignored as we only superficially
        apply state.
        """
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        self.tagger.delete_all_tags_for_resource(
            delivery_stream.delivery_stream_arn)

        delivery_stream.delivery_stream_status = "DELETING"
        self.delivery_streams.pop(delivery_stream_name)

    def describe_delivery_stream(self, delivery_stream_name, limit,
                                 exclusive_start_destination_id):  # pylint: disable=unused-argument
        """Return description of specified delivery stream and its status.

        Note:  the 'limit' and 'exclusive_start_destination_id' parameters
        are not currently processed/implemented.
        """
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        result = {"DeliveryStreamDescription": {"HasMoreDestinations": False}}
        for attribute, attribute_value in vars(delivery_stream).items():
            if not attribute_value:
                continue

            # Convert from attribute's snake case to camel case for outgoing
            # JSON.
            name = "".join([x.capitalize() for x in attribute.split("_")])

            # Fooey ... always an exception to the rule:
            if name == "DeliveryStreamArn":
                name = "DeliveryStreamARN"

            if name != "Destinations":
                if name == "Source":
                    result["DeliveryStreamDescription"][name] = {
                        "KinesisStreamSourceDescription": attribute_value
                    }
                else:
                    result["DeliveryStreamDescription"][name] = attribute_value
                continue

            result["DeliveryStreamDescription"]["Destinations"] = []
            for destination in attribute_value:
                description = {}
                for key, value in destination.items():
                    if key == "destination_id":
                        description["DestinationId"] = value
                    else:
                        description[f"{key}DestinationDescription"] = value

                result["DeliveryStreamDescription"]["Destinations"].append(
                    description)

        return result

    def list_delivery_streams(self, limit, delivery_stream_type,
                              exclusive_start_delivery_stream_name):
        """Return list of delivery streams in alphabetic order of names."""
        result = {"DeliveryStreamNames": [], "HasMoreDeliveryStreams": False}
        if not self.delivery_streams:
            return result

        # If delivery_stream_type is specified, filter out any stream that's
        # not of that type.
        stream_list = self.delivery_streams.keys()
        if delivery_stream_type:
            stream_list = [
                x for x in stream_list
                if self.delivery_streams[x].delivery_stream_type ==
                delivery_stream_type
            ]

        # The list is sorted alphabetically, not alphanumerically.
        sorted_list = sorted(stream_list)

        # Determine the limit or number of names to return in the list.
        limit = limit or DeliveryStream.MAX_STREAMS_PER_REGION

        # If a starting delivery stream name is given, find the index into
        # the sorted list, then add one to get the name following it.  If the
        # exclusive_start_delivery_stream_name doesn't exist, it's ignored.
        start = 0
        if exclusive_start_delivery_stream_name:
            if self.delivery_streams.get(exclusive_start_delivery_stream_name):
                start = sorted_list.index(
                    exclusive_start_delivery_stream_name) + 1

        result["DeliveryStreamNames"] = sorted_list[start:start + limit]
        if len(sorted_list) > (start + limit):
            result["HasMoreDeliveryStreams"] = True
        return result

    def list_tags_for_delivery_stream(self, delivery_stream_name,
                                      exclusive_start_tag_key, limit):
        """Return list of tags."""
        result = {"Tags": [], "HasMoreTags": False}
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        tags = self.tagger.list_tags_for_resource(
            delivery_stream.delivery_stream_arn)["Tags"]
        keys = self.tagger.extract_tag_names(tags)

        # If a starting tag is given and can be found, find the index into
        # tags, then add one to get the tag following it.
        start = 0
        if exclusive_start_tag_key:
            if exclusive_start_tag_key in keys:
                start = keys.index(exclusive_start_tag_key) + 1

        limit = limit or MAX_TAGS_PER_DELIVERY_STREAM
        result["Tags"] = tags[start:start + limit]
        if len(tags) > (start + limit):
            result["HasMoreTags"] = True
        return result

    def put_record(self, delivery_stream_name, record):
        """Write a single data record into a Kinesis Data firehose stream."""
        result = self.put_record_batch(delivery_stream_name, [record])
        return {
            "RecordId": result["RequestResponses"][0]["RecordId"],
            "Encrypted": False,
        }

    @staticmethod
    def put_http_records(http_destination, records):
        """Put records to a HTTP destination."""
        # Mostly copied from localstack
        url = http_destination["EndpointConfiguration"]["Url"]
        headers = {"Content-Type": "application/json"}
        record_to_send = {
            "requestId": str(uuid4()),
            "timestamp": int(time()),
            "records": [{
                "data": record["Data"]
            } for record in records],
        }
        try:
            requests.post(url, json=record_to_send, headers=headers)
        except Exception as exc:
            # This could be better ...
            raise RuntimeError(
                "Firehose PutRecord(Batch) to HTTP destination failed"
            ) from exc
        return [{"RecordId": str(uuid4())} for _ in range(len(records))]

    @staticmethod
    def _format_s3_object_path(delivery_stream_name, version_id, prefix):
        """Return a S3 object path in the expected format."""
        # Taken from LocalStack's firehose logic, with minor changes.
        # See https://docs.aws.amazon.com/firehose/latest/dev/basic-deliver.html#s3-object-name
        # Path prefix pattern: myApp/YYYY/MM/DD/HH/
        # Object name pattern:
        # DeliveryStreamName-DeliveryStreamVersion-YYYY-MM-DD-HH-MM-SS-RandomString
        prefix = f"{prefix}{'' if prefix.endswith('/') else '/'}"
        now = datetime.utcnow()
        return (f"{prefix}{now.strftime('%Y/%m/%d/%H')}/"
                f"{delivery_stream_name}-{version_id}-"
                f"{now.strftime('%Y-%m-%d-%H-%M-%S')}-{str(uuid4())}")

    def put_s3_records(self, delivery_stream_name, version_id, s3_destination,
                       records):
        """Put records to a ExtendedS3 or S3 destination."""
        # Taken from LocalStack's firehose logic, with minor changes.
        bucket_name = s3_destination["BucketARN"].split(":")[-1]
        prefix = s3_destination.get("Prefix", "")
        object_path = self._format_s3_object_path(delivery_stream_name,
                                                  version_id, prefix)

        batched_data = b"".join([b64decode(r["Data"]) for r in records])
        try:
            s3_backend.put_object(bucket_name, object_path, batched_data)
        except Exception as exc:
            # This could be better ...
            raise RuntimeError(
                "Firehose PutRecord(Batch to S3 destination failed") from exc
        return [{"RecordId": str(uuid4())} for _ in range(len(records))]

    def put_record_batch(self, delivery_stream_name, records):
        """Write multiple data records into a Kinesis Data firehose stream."""
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        request_responses = []
        for destination in delivery_stream.destinations:
            if "ExtendedS3" in destination:
                # ExtendedS3 will be handled like S3,but in the future
                # this will probably need to be revisited.  This destination
                # must be listed before S3 otherwise both destinations will
                # be processed instead of just ExtendedS3.
                request_responses = self.put_s3_records(
                    delivery_stream_name,
                    delivery_stream.version_id,
                    destination["ExtendedS3"],
                    records,
                )
            elif "S3" in destination:
                request_responses = self.put_s3_records(
                    delivery_stream_name,
                    delivery_stream.version_id,
                    destination["S3"],
                    records,
                )
            elif "HttpEndpoint" in destination:
                request_responses = self.put_http_records(
                    destination["HttpEndpoint"], records)
            elif "Elasticsearch" in destination or "Redshift" in destination:
                # This isn't implmented as these services aren't implemented,
                # so ignore the data, but return a "proper" response.
                request_responses = [{
                    "RecordId": str(uuid4())
                } for _ in range(len(records))]

        return {
            "FailedPutCount": 0,
            "Encrypted": False,
            "RequestResponses": request_responses,
        }

    def tag_delivery_stream(self, delivery_stream_name, tags):
        """Add/update tags for specified delivery stream."""
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        if len(tags) > MAX_TAGS_PER_DELIVERY_STREAM:
            raise ValidationException(
                f"1 validation error detected: Value '{tags}' at 'tags' "
                f"failed to satisify contstraint: Member must have length "
                f"less than or equal to {MAX_TAGS_PER_DELIVERY_STREAM}")

        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)

        self.tagger.tag_resource(delivery_stream.delivery_stream_arn, tags)

    def untag_delivery_stream(self, delivery_stream_name, tag_keys):
        """Removes tags from specified delivery stream."""
        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under account {get_account_id()} "
                f"not found.")

        # If a tag key doesn't exist for the stream, boto3 ignores it.
        self.tagger.untag_resource_using_names(
            delivery_stream.delivery_stream_arn, tag_keys)

    def update_destination(
        self,
        delivery_stream_name,
        current_delivery_stream_version_id,
        destination_id,
        s3_destination_update,
        extended_s3_destination_update,
        s3_backup_mode,
        redshift_destination_update,
        elasticsearch_destination_update,
        splunk_destination_update,
        http_endpoint_destination_update,
    ):  # pylint: disable=unused-argument,too-many-arguments,too-many-locals
        """Updates specified destination of specified delivery stream."""
        (destination_name,
         destination_config) = find_destination_config_in_args(locals())

        delivery_stream = self.delivery_streams.get(delivery_stream_name)
        if not delivery_stream:
            raise ResourceNotFoundException(
                f"Firehose {delivery_stream_name} under accountId "
                f"{get_account_id()} not found.")

        if destination_name == "Splunk":
            warnings.warn(
                "A Splunk destination delivery stream is not yet implemented")

        if delivery_stream.version_id != current_delivery_stream_version_id:
            raise ConcurrentModificationException(
                f"Cannot update firehose: {delivery_stream_name} since the "
                f"current version id: {delivery_stream.version_id} and "
                f"specified version id: {current_delivery_stream_version_id} "
                f"do not match")

        destination = {}
        destination_idx = 0
        for destination in delivery_stream.destinations:
            if destination["destination_id"] == destination_id:
                break
            destination_idx += 1
        else:
            raise InvalidArgumentException(
                "Destination Id {destination_id} not found")

        # Switching between Amazon ES and other services is not supported.
        # For an Amazon ES destination, you can only update to another Amazon
        # ES destination.  Same with HTTP.  Didn't test Splunk.
        if (destination_name == "Elasticsearch" and "Elasticsearch"
                not in destination) or (destination_name == "HttpEndpoint"
                                        and "HttpEndpoint" not in destination):
            raise InvalidArgumentException(
                f"Changing the destination type to or from {destination_name} "
                f"is not supported at this time.")

        # If this is a different type of destination configuration,
        # the existing configuration is reset first.
        if destination_name in destination:
            delivery_stream.destinations[destination_idx][
                destination_name].update(destination_config)
        else:
            delivery_stream.destinations[destination_idx] = {
                "destination_id": destination_id,
                destination_name: destination_config,
            }

        # Once S3 is updated to an ExtendedS3 destination, both remain in
        # the destination.  That means when one is updated, the other needs
        # to be updated as well.  The problem is that they don't have the
        # same fields.
        if destination_name == "ExtendedS3":
            delivery_stream.destinations[destination_idx][
                "S3"] = create_s3_destination_config(destination_config)
        elif destination_name == "S3" and "ExtendedS3" in destination:
            destination["ExtendedS3"] = {
                k: v
                for k, v in destination["S3"].items()
                if k in destination["ExtendedS3"]
            }

        # Increment version number and update the timestamp.
        delivery_stream.version_id = str(
            int(current_delivery_stream_version_id) + 1)
        delivery_stream.last_update_timestamp = datetime.now(
            timezone.utc).isoformat()

        # Unimplemented: processing of the "S3BackupMode" parameter.  Per the
        # documentation:  "You can update a delivery stream to enable Amazon
        # S3 backup if it is disabled.  If backup is enabled, you can't update
        # the delivery stream to disable it."

    def lookup_name_from_arn(self, arn):
        """Given an ARN, return the associated delivery stream name."""
        return self.delivery_streams.get(arn.split("/")[-1])

    def send_log_event(
        self,
        delivery_stream_arn,
        filter_name,
        log_group_name,
        log_stream_name,
        log_events,
    ):  # pylint:  disable=too-many-arguments
        """Send log events to a S3 bucket after encoding and gzipping it."""
        data = {
            "logEvents": log_events,
            "logGroup": log_group_name,
            "logStream": log_stream_name,
            "messageType": "DATA_MESSAGE",
            "owner": get_account_id(),
            "subscriptionFilters": [filter_name],
        }

        output = io.BytesIO()
        with GzipFile(fileobj=output, mode="w") as fhandle:
            fhandle.write(
                json.dumps(data, separators=(",", ":")).encode("utf-8"))
        gzipped_payload = b64encode(output.getvalue()).decode("utf-8")

        delivery_stream = self.lookup_name_from_arn(delivery_stream_arn)
        self.put_s3_records(
            delivery_stream.delivery_stream_name,
            delivery_stream.version_id,
            delivery_stream.destinations[0]["S3"],
            [{
                "Data": gzipped_payload
            }],
        )
Пример #9
0
class TimestreamWriteBackend(BaseBackend):
    def __init__(self, region_name, account_id):
        super().__init__(region_name, account_id)
        self.databases = dict()
        self.tagging_service = TaggingService()

    def create_database(self, database_name, kms_key_id, tags):
        database = TimestreamDatabase(self.region_name, database_name, kms_key_id)
        self.databases[database_name] = database
        self.tagging_service.tag_resource(database.arn, tags)
        return database

    def delete_database(self, database_name):
        del self.databases[database_name]

    def describe_database(self, database_name):
        if database_name not in self.databases:
            raise ResourceNotFound(f"The database {database_name} does not exist.")
        return self.databases[database_name]

    def list_databases(self):
        return self.databases.values()

    def update_database(self, database_name, kms_key_id):
        database = self.databases[database_name]
        database.update(kms_key_id=kms_key_id)
        return database

    def create_table(
        self,
        database_name,
        table_name,
        retention_properties,
        tags,
        magnetic_store_write_properties,
    ):
        database = self.describe_database(database_name)
        table = database.create_table(
            table_name, retention_properties, magnetic_store_write_properties
        )
        self.tagging_service.tag_resource(table.arn, tags)
        return table

    def delete_table(self, database_name, table_name):
        database = self.describe_database(database_name)
        database.delete_table(table_name)

    def describe_table(self, database_name, table_name):
        database = self.describe_database(database_name)
        table = database.describe_table(table_name)
        return table

    def list_tables(self, database_name):
        database = self.describe_database(database_name)
        tables = database.list_tables()
        return tables

    def update_table(
        self,
        database_name,
        table_name,
        retention_properties,
        magnetic_store_write_properties,
    ):
        database = self.describe_database(database_name)
        table = database.update_table(
            table_name, retention_properties, magnetic_store_write_properties
        )
        return table

    def write_records(self, database_name, table_name, records):
        database = self.describe_database(database_name)
        table = database.describe_table(table_name)
        table.write_records(records)

    def describe_endpoints(self):
        # https://docs.aws.amazon.com/timestream/latest/developerguide/Using-API.endpoint-discovery.how-it-works.html
        # Usually, the address look like this:
        # ingest-cell1.timestream.us-east-1.amazonaws.com
        # Where 'cell1' can be any number, 'cell2', 'cell3', etc - whichever endpoint happens to be available for that particular account
        # We don't implement a cellular architecture in Moto though, so let's keep it simple
        return {
            "Endpoints": [
                {
                    "Address": f"ingest.timestream.{self.region_name}.amazonaws.com",
                    "CachePeriodInMinutes": 1440,
                }
            ]
        }

    def list_tags_for_resource(self, resource_arn):
        return self.tagging_service.list_tags_for_resource(resource_arn)

    def tag_resource(self, resource_arn, tags):
        self.tagging_service.tag_resource(resource_arn, tags)

    def untag_resource(self, resource_arn, tag_keys):
        self.tagging_service.untag_resource_using_names(resource_arn, tag_keys)
Пример #10
0
class ServiceDiscoveryBackend(BaseBackend):
    """Implementation of ServiceDiscovery APIs."""

    def __init__(self, region_name=None):
        self.region_name = region_name
        self.operations = dict()
        self.namespaces = dict()
        self.services = dict()
        self.tagger = TaggingService()

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def list_namespaces(self):
        """
        Pagination or the Filters-parameter is not yet implemented
        """
        return self.namespaces.values()

    def create_http_namespace(self, name, creator_request_id, description, tags):
        namespace = Namespace(
            region=self.region_name,
            name=name,
            ns_type="HTTP",
            creator_request_id=creator_request_id,
            description=description,
            dns_properties={"SOA": {}},
            http_properties={"HttpName": name},
        )
        self.namespaces[namespace.id] = namespace
        if tags:
            self.tagger.tag_resource(namespace.arn, tags)
        operation_id = self._create_operation(
            "CREATE_NAMESPACE", targets={"NAMESPACE": namespace.id}
        )
        return operation_id

    def _create_operation(self, op_type, targets):
        operation = Operation(operation_type=op_type, targets=targets)
        self.operations[operation.id] = operation
        operation_id = operation.id
        return operation_id

    def delete_namespace(self, namespace_id):
        if namespace_id not in self.namespaces:
            raise NamespaceNotFound(namespace_id)
        del self.namespaces[namespace_id]
        operation_id = self._create_operation(
            op_type="DELETE_NAMESPACE", targets={"NAMESPACE": namespace_id}
        )
        return operation_id

    def get_namespace(self, namespace_id):
        if namespace_id not in self.namespaces:
            raise NamespaceNotFound(namespace_id)
        return self.namespaces[namespace_id]

    def list_operations(self):
        """
        Pagination or the Filters-argument is not yet implemented
        """
        # Operations for namespaces will only be listed as long as namespaces exist
        self.operations = {
            op_id: op
            for op_id, op in self.operations.items()
            if op.targets.get("NAMESPACE") in self.namespaces
        }
        return self.operations.values()

    def get_operation(self, operation_id):
        if operation_id not in self.operations:
            raise OperationNotFound()
        return self.operations[operation_id]

    def tag_resource(self, resource_arn, tags):
        self.tagger.tag_resource(resource_arn, tags)

    def untag_resource(self, resource_arn, tag_keys):
        self.tagger.untag_resource_using_names(resource_arn, tag_keys)

    def list_tags_for_resource(self, resource_arn):
        return self.tagger.list_tags_for_resource(resource_arn)

    def create_private_dns_namespace(
        self, name, creator_request_id, description, vpc, tags, properties
    ):
        for namespace in self.namespaces.values():
            if namespace.vpc == vpc:
                raise ConflictingDomainExists(vpc)
        dns_properties = (properties or {}).get("DnsProperties", {})
        dns_properties["HostedZoneId"] = "hzi"
        namespace = Namespace(
            region=self.region_name,
            name=name,
            ns_type="DNS_PRIVATE",
            creator_request_id=creator_request_id,
            description=description,
            dns_properties=dns_properties,
            http_properties={},
            vpc=vpc,
        )
        self.namespaces[namespace.id] = namespace
        if tags:
            self.tagger.tag_resource(namespace.arn, tags)
        operation_id = self._create_operation(
            "CREATE_NAMESPACE", targets={"NAMESPACE": namespace.id}
        )
        return operation_id

    def create_public_dns_namespace(
        self, name, creator_request_id, description, tags, properties
    ):
        dns_properties = (properties or {}).get("DnsProperties", {})
        dns_properties["HostedZoneId"] = "hzi"
        namespace = Namespace(
            region=self.region_name,
            name=name,
            ns_type="DNS_PUBLIC",
            creator_request_id=creator_request_id,
            description=description,
            dns_properties=dns_properties,
            http_properties={},
        )
        self.namespaces[namespace.id] = namespace
        if tags:
            self.tagger.tag_resource(namespace.arn, tags)
        operation_id = self._create_operation(
            "CREATE_NAMESPACE", targets={"NAMESPACE": namespace.id}
        )
        return operation_id

    def create_service(
        self,
        name,
        namespace_id,
        creator_request_id,
        description,
        dns_config,
        health_check_config,
        health_check_custom_config,
        tags,
        service_type,
    ):
        service = Service(
            region=self.region_name,
            name=name,
            namespace_id=namespace_id,
            description=description,
            creator_request_id=creator_request_id,
            dns_config=dns_config,
            health_check_config=health_check_config,
            health_check_custom_config=health_check_custom_config,
            service_type=service_type,
        )
        self.services[service.id] = service
        if tags:
            self.tagger.tag_resource(service.arn, tags)
        return service

    def get_service(self, service_id):
        if service_id not in self.services:
            raise ServiceNotFound(service_id)
        return self.services[service_id]

    def delete_service(self, service_id):
        self.services.pop(service_id, None)

    def list_services(self):
        """
        Pagination or the Filters-argument is not yet implemented
        """
        return self.services.values()

    def update_service(self, service_id, details):
        service = self.get_service(service_id)
        service.update(details=details)
        operation_id = self._create_operation(
            "UPDATE_SERVICE", targets={"SEVICE": service.id}
        )
        return operation_id
Пример #11
0
class EFSBackend(BaseBackend):
    """The backend manager of EFS resources.

    This is the state-machine for each region, tracking the file systems, mount targets,
    and eventually access points that are deployed. Creating, updating, and destroying
    such resources should always go through this class.
    """

    def __init__(self, region_name=None):
        super().__init__()
        self.region_name = region_name
        self.creation_tokens = set()
        self.access_points = dict()
        self.file_systems_by_id = {}
        self.mount_targets_by_id = {}
        self.next_markers = {}
        self.tagging_service = TaggingService()

    def reset(self):
        # preserve region
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def _mark_description(self, corpus, max_items):
        if max_items < len(corpus):
            new_corpus = corpus[max_items:]
            new_corpus_dict = [c.info_json() for c in new_corpus]
            new_hash = md5(json.dumps(new_corpus_dict).encode("utf-8"))
            next_marker = new_hash.hexdigest()
            self.next_markers[next_marker] = new_corpus
        else:
            next_marker = None
        return next_marker

    @property
    def ec2_backend(self):
        return ec2_backends[self.region_name]

    def create_file_system(
        self,
        creation_token,
        performance_mode,
        encrypted,
        kms_key_id,
        throughput_mode,
        provisioned_throughput_in_mibps,
        availability_zone_name,
        backup,
        tags,
    ):
        """Create a new EFS File System Volume.

        https://docs.aws.amazon.com/efs/latest/ug/API_CreateFileSystem.html
        """
        if not creation_token:
            raise ValueError("No creation token given.")
        if creation_token in self.creation_tokens:
            raise FileSystemAlreadyExists(creation_token)

        # Create a new file system ID:
        def make_id():
            return "fs-{}".format(get_random_hex())

        fsid = make_id()
        while fsid in self.file_systems_by_id:
            fsid = make_id()
        self.file_systems_by_id[fsid] = FileSystem(
            self.region_name,
            creation_token,
            fsid,
            context=self,
            performance_mode=performance_mode,
            encrypted=encrypted,
            kms_key_id=kms_key_id,
            throughput_mode=throughput_mode,
            provisioned_throughput_in_mibps=provisioned_throughput_in_mibps,
            availability_zone_name=availability_zone_name,
            backup=backup,
        )
        self.tag_resource(fsid, tags)
        self.creation_tokens.add(creation_token)
        return self.file_systems_by_id[fsid]

    def describe_file_systems(
        self, marker=None, max_items=10, creation_token=None, file_system_id=None
    ):
        """Describe all the EFS File Systems, or specific File Systems.

        https://docs.aws.amazon.com/efs/latest/ug/API_DescribeFileSystems.html
        """
        # Restrict the possible corpus of resules based on inputs.
        if creation_token and file_system_id:
            raise BadRequest(
                "Request cannot contain both a file system ID and a creation token."
            )
        elif creation_token:
            # Handle the creation token case.
            corpus = []
            for fs in self.file_systems_by_id.values():
                if fs.creation_token == creation_token:
                    corpus.append(fs)
        elif file_system_id:
            # Handle the case that a file_system_id is given.
            if file_system_id not in self.file_systems_by_id:
                raise FileSystemNotFound(file_system_id)
            corpus = [self.file_systems_by_id[file_system_id]]
        elif marker is not None:
            # Handle the case that a marker is given.
            if marker not in self.next_markers:
                raise BadRequest("Invalid Marker")
            corpus = self.next_markers[marker]
        else:
            # Handle the vanilla case.
            corpus = [fs for fs in self.file_systems_by_id.values()]

        # Handle the max_items parameter.
        file_systems = corpus[:max_items]
        next_marker = self._mark_description(corpus, max_items)
        return next_marker, file_systems

    def create_mount_target(
        self, file_system_id, subnet_id, ip_address=None, security_groups=None
    ):
        """Create a new EFS Mount Target for a given File System to a given subnet.

        Note that you can only create one mount target for each availability zone
        (which is implied by the subnet ID).

        https://docs.aws.amazon.com/efs/latest/ug/API_CreateMountTarget.html
        """
        # Get the relevant existing resources
        try:
            subnet = self.ec2_backend.get_subnet(subnet_id)
        except InvalidSubnetIdError:
            raise SubnetNotFound(subnet_id)
        if file_system_id not in self.file_systems_by_id:
            raise FileSystemNotFound(file_system_id)
        file_system = self.file_systems_by_id[file_system_id]

        # Validate the security groups.
        if security_groups:
            sg_lookup = {sg.id for sg in self.ec2_backend.describe_security_groups()}
            for sg_id in security_groups:
                if sg_id not in sg_lookup:
                    raise SecurityGroupNotFound(sg_id)

        # Create the new mount target
        mount_target = MountTarget(file_system, subnet, ip_address, security_groups)

        # Establish the network interface.
        network_interface = self.ec2_backend.create_network_interface(
            subnet, [mount_target.ip_address], group_ids=security_groups
        )
        mount_target.set_network_interface(network_interface)

        # Record the new mount target
        self.mount_targets_by_id[mount_target.mount_target_id] = mount_target
        return mount_target

    def describe_mount_targets(
        self, max_items, file_system_id, mount_target_id, access_point_id, marker
    ):
        """Describe the mount targets given an access point ID, mount target ID or a file system ID.

        https://docs.aws.amazon.com/efs/latest/ug/API_DescribeMountTargets.html
        """
        # Restrict the possible corpus of results based on inputs.
        if not (bool(file_system_id) ^ bool(mount_target_id) ^ bool(access_point_id)):
            raise BadRequest("Must specify exactly one mutually exclusive parameter.")

        if access_point_id:
            file_system_id = self.access_points[access_point_id].file_system_id

        if file_system_id:
            # Handle the case that a file_system_id is given.
            if file_system_id not in self.file_systems_by_id:
                raise FileSystemNotFound(file_system_id)
            corpus = [
                mt
                for mt in self.file_systems_by_id[file_system_id].iter_mount_targets()
            ]
        elif mount_target_id:
            if mount_target_id not in self.mount_targets_by_id:
                raise MountTargetNotFound(mount_target_id)
            # Handle mount target specification case.
            corpus = [self.mount_targets_by_id[mount_target_id]]

        # Handle the case that a marker is given. Note that the handling is quite
        # different from that in describe_file_systems.
        if marker is not None:
            if marker not in self.next_markers:
                raise BadRequest("Invalid Marker")
            corpus_mtids = {m.mount_target_id for m in corpus}
            marked_mtids = {m.mount_target_id for m in self.next_markers[marker]}
            mt_ids = corpus_mtids & marked_mtids
            corpus = [self.mount_targets_by_id[mt_id] for mt_id in mt_ids]

        # Handle the max_items parameter.
        mount_targets = corpus[:max_items]
        next_marker = self._mark_description(corpus, max_items)
        return next_marker, mount_targets

    def delete_file_system(self, file_system_id):
        """Delete the file system specified by the given file_system_id.

        Note that mount targets must be deleted first.

        https://docs.aws.amazon.com/efs/latest/ug/API_DeleteFileSystem.html
        """
        if file_system_id not in self.file_systems_by_id:
            raise FileSystemNotFound(file_system_id)

        file_system = self.file_systems_by_id[file_system_id]
        if file_system.number_of_mount_targets > 0:
            raise FileSystemInUse(
                "Must delete all mount targets before deleting file system."
            )

        del self.file_systems_by_id[file_system_id]
        self.creation_tokens.remove(file_system.creation_token)
        return

    def delete_mount_target(self, mount_target_id):
        """Delete a mount target specified by the given mount_target_id.

        Note that this will also delete a network interface.

        https://docs.aws.amazon.com/efs/latest/ug/API_DeleteMountTarget.html
        """
        if mount_target_id not in self.mount_targets_by_id:
            raise MountTargetNotFound(mount_target_id)

        mount_target = self.mount_targets_by_id[mount_target_id]
        self.ec2_backend.delete_network_interface(mount_target.network_interface_id)
        del self.mount_targets_by_id[mount_target_id]
        mount_target.clean_up()
        return

    def describe_backup_policy(self, file_system_id):
        backup_policy = self.file_systems_by_id[file_system_id].backup_policy
        if not backup_policy:
            raise PolicyNotFound("None")
        return backup_policy

    def put_lifecycle_configuration(self, file_system_id, policies):
        _, fss = self.describe_file_systems(file_system_id=file_system_id)
        file_system = fss[0]
        file_system.lifecycle_policies = policies

    def describe_lifecycle_configuration(self, file_system_id):
        _, fss = self.describe_file_systems(file_system_id=file_system_id)
        file_system = fss[0]
        return file_system.lifecycle_policies

    def describe_mount_target_security_groups(self, mount_target_id):
        if mount_target_id not in self.mount_targets_by_id:
            raise MountTargetNotFound(mount_target_id)

        mount_target = self.mount_targets_by_id[mount_target_id]
        return mount_target.security_groups

    def modify_mount_target_security_groups(self, mount_target_id, security_groups):
        if mount_target_id not in self.mount_targets_by_id:
            raise MountTargetNotFound(mount_target_id)

        mount_target = self.mount_targets_by_id[mount_target_id]
        mount_target.security_groups = security_groups

        self.ec2_backend.modify_network_interface_attribute(
            eni_id=mount_target.network_interface_id, group_ids=security_groups
        )

    def create_access_point(
        self, client_token, tags, file_system_id, posix_user, root_directory
    ):
        name = next((tag["Value"] for tag in tags if tag["Key"] == "Name"), None)
        access_point = AccessPoint(
            self.region_name,
            client_token,
            file_system_id,
            name,
            posix_user,
            root_directory,
            context=self,
        )
        self.tagging_service.tag_resource(access_point.access_point_id, tags)
        self.access_points[access_point.access_point_id] = access_point
        return access_point

    def describe_access_points(self, access_point_id):
        """
        Pagination is not yet implemented
        """
        if access_point_id:
            if access_point_id not in self.access_points:
                raise AccessPointNotFound(access_point_id)
            return [self.access_points[access_point_id]]
        return self.access_points.values()

    def delete_access_point(self, access_point_id):
        self.access_points.pop(access_point_id, None)

    def list_tags_for_resource(self, resource_id):
        return self.tagging_service.list_tags_for_resource(resource_id)["Tags"]

    def tag_resource(self, resource_id, tags):
        self.tagging_service.tag_resource(resource_id, tags)

    def untag_resource(self, resource_id, tag_keys):
        self.tagging_service.untag_resource_using_names(resource_id, tag_keys)
Пример #12
0
class DirectoryServiceBackend(BaseBackend):
    """Implementation of DirectoryService APIs."""
    def __init__(self, region_name=None):
        self.region_name = region_name
        self.directories = {}
        self.tagger = TaggingService()

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """List of dicts representing default VPC endpoints for this service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "ds")

    @staticmethod
    def _verify_subnets(region, vpc_settings):
        """Verify subnets are valid, else raise an exception.

        If settings are valid, add AvailabilityZones to vpc_settings.
        """
        if len(vpc_settings["SubnetIds"]) != 2:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones.")

        # Subnet IDs are checked before the VPC ID.  The Subnet IDs must
        # be valid and in different availability zones.
        try:
            subnets = ec2_backends[region].get_all_subnets(
                subnet_ids=vpc_settings["SubnetIds"])
        except InvalidSubnetIdError as exc:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones.") from exc

        regions = [subnet.availability_zone for subnet in subnets]
        if regions[0] == regions[1]:
            raise ClientException(
                "Invalid subnet ID(s). The two subnets must be in "
                "different Availability Zones.")

        vpcs = ec2_backends[region].describe_vpcs()
        if vpc_settings["VpcId"] not in [x.id for x in vpcs]:
            raise ClientException("Invalid VPC ID.")
        vpc_settings["AvailabilityZones"] = regions

    def connect_directory(
        self,
        region,
        name,
        short_name,
        password,
        description,
        size,
        connect_settings,
        tags,
    ):  # pylint: disable=too-many-arguments
        """Create a fake AD Connector."""
        if len(self.directories) > Directory.CONNECTED_DIRECTORIES_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CONNECTED_DIRECTORIES_LIMIT} directories may be created"
            )

        validate_args([
            ("password", password),
            ("size", size),
            ("name", name),
            ("description", description),
            ("shortName", short_name),
            (
                "connectSettings.vpcSettings.subnetIds",
                connect_settings["SubnetIds"],
            ),
            (
                "connectSettings.customerUserName",
                connect_settings["CustomerUserName"],
            ),
            ("connectSettings.customerDnsIps",
             connect_settings["CustomerDnsIps"]),
        ])
        # ConnectSettings and VpcSettings both have a VpcId and Subnets.
        self._verify_subnets(region, connect_settings)

        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            region,
            name,
            password,
            "ADConnector",
            size=size,
            connect_settings=connect_settings,
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags or [])
        return directory.directory_id

    def create_directory(self, region, name, short_name, password, description,
                         size, vpc_settings, tags):  # pylint: disable=too-many-arguments
        """Create a fake Simple Ad Directory."""
        if len(self.directories) > Directory.CLOUDONLY_DIRECTORIES_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CLOUDONLY_DIRECTORIES_LIMIT} directories may be created"
            )

        # botocore doesn't look for missing vpc_settings, but boto3 does.
        if not vpc_settings:
            raise InvalidParameterException("VpcSettings must be specified.")
        validate_args([
            ("password", password),
            ("size", size),
            ("name", name),
            ("description", description),
            ("shortName", short_name),
            ("vpcSettings.subnetIds", vpc_settings["SubnetIds"]),
        ])
        self._verify_subnets(region, vpc_settings)

        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            region,
            name,
            password,
            "SimpleAD",
            size=size,
            vpc_settings=vpc_settings,
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags or [])
        return directory.directory_id

    def _validate_directory_id(self, directory_id):
        """Raise an exception if the directory id is invalid or unknown."""
        # Validation of ID takes precedence over a check for its existence.
        validate_args([("directoryId", directory_id)])
        if directory_id not in self.directories:
            raise EntityDoesNotExistException(
                f"Directory {directory_id} does not exist")

    def create_alias(self, directory_id, alias):
        """Create and assign an alias to a directory."""
        self._validate_directory_id(directory_id)

        # The default alias name is the same as the directory name.  Check
        # whether this directory was already given an alias.
        directory = self.directories[directory_id]
        if directory.alias != directory_id:
            raise InvalidParameterException(
                "The directory in the request already has an alias. That "
                "alias must be deleted before a new alias can be created.")

        # Is the alias already in use?
        if alias in [x.alias for x in self.directories.values()]:
            raise EntityAlreadyExistsException(
                f"Alias '{alias}' already exists.")
        validate_args([("alias", alias)])

        directory.update_alias(alias)
        return {"DirectoryId": directory_id, "Alias": alias}

    def create_microsoft_ad(
        self,
        region,
        name,
        short_name,
        password,
        description,
        vpc_settings,
        edition,
        tags,
    ):  # pylint: disable=too-many-arguments
        """Create a fake Microsoft Ad Directory."""
        if len(self.directories) > Directory.CLOUDONLY_MICROSOFT_AD_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CLOUDONLY_MICROSOFT_AD_LIMIT} directories may be created"
            )

        # boto3 looks for missing vpc_settings for create_microsoft_ad().
        validate_args([
            ("password", password),
            ("edition", edition),
            ("name", name),
            ("description", description),
            ("shortName", short_name),
            ("vpcSettings.subnetIds", vpc_settings["SubnetIds"]),
        ])
        self._verify_subnets(region, vpc_settings)

        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            region,
            name,
            password,
            "MicrosoftAD",
            vpc_settings=vpc_settings,
            short_name=short_name,
            description=description,
            edition=edition,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags or [])
        return directory.directory_id

    def delete_directory(self, directory_id):
        """Delete directory with the matching ID."""
        self._validate_directory_id(directory_id)
        self.tagger.delete_all_tags_for_resource(directory_id)
        self.directories.pop(directory_id)
        return directory_id

    def disable_sso(self, directory_id, username=None, password=None):
        """Disable single-sign on for a directory."""
        self._validate_directory_id(directory_id)
        validate_args([("ssoPassword", password), ("userName", username)])
        directory = self.directories[directory_id]
        directory.enable_sso(False)

    def enable_sso(self, directory_id, username=None, password=None):
        """Enable single-sign on for a directory."""
        self._validate_directory_id(directory_id)
        validate_args([("ssoPassword", password), ("userName", username)])

        directory = self.directories[directory_id]
        if directory.alias == directory_id:
            raise ClientException(
                f"An alias is required before enabling SSO. DomainId={directory_id}"
            )

        directory = self.directories[directory_id]
        directory.enable_sso(True)

    @paginate(pagination_model=PAGINATION_MODEL)
    def describe_directories(self,
                             directory_ids=None,
                             next_token=None,
                             limit=0):  # pylint: disable=unused-argument
        """Return info on all directories or directories with matching IDs."""
        for directory_id in directory_ids or self.directories:
            self._validate_directory_id(directory_id)

        directories = list(self.directories.values())
        if directory_ids:
            directories = [
                x for x in directories if x.directory_id in directory_ids
            ]
        return sorted(directories, key=lambda x: x.launch_time)

    def get_directory_limits(self):
        """Return hard-coded limits for the directories."""
        counts = {"SimpleAD": 0, "MicrosoftAD": 0, "ConnectedAD": 0}
        for directory in self.directories.values():
            if directory.directory_type == "SimpleAD":
                counts["SimpleAD"] += 1
            elif directory.directory_type in [
                    "MicrosoftAD", "SharedMicrosoftAD"
            ]:
                counts["MicrosoftAD"] += 1
            elif directory.directory_type == "ADConnector":
                counts["ConnectedAD"] += 1

        return {
            "CloudOnlyDirectoriesLimit":
            Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyDirectoriesCurrentCount":
            counts["SimpleAD"],
            "CloudOnlyDirectoriesLimitReached":
            counts["SimpleAD"] == Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyMicrosoftADLimit":
            Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "CloudOnlyMicrosoftADCurrentCount":
            counts["MicrosoftAD"],
            "CloudOnlyMicrosoftADLimitReached":
            counts["MicrosoftAD"] == Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "ConnectedDirectoriesLimit":
            Directory.CONNECTED_DIRECTORIES_LIMIT,
            "ConnectedDirectoriesCurrentCount":
            counts["ConnectedAD"],
            "ConnectedDirectoriesLimitReached":
            counts["ConnectedAD"] == Directory.CONNECTED_DIRECTORIES_LIMIT,
        }

    def add_tags_to_resource(self, resource_id, tags):
        """Add or overwrite one or more tags for specified directory."""
        self._validate_directory_id(resource_id)
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise TagLimitExceededException("Tag limit exceeded")
        self.tagger.tag_resource(resource_id, tags)

    def remove_tags_from_resource(self, resource_id, tag_keys):
        """Removes tags from a directory."""
        self._validate_directory_id(resource_id)
        self.tagger.untag_resource_using_names(resource_id, tag_keys)

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags_for_resource(
        self,
        resource_id,
        next_token=None,
        limit=None,
    ):  # pylint: disable=unused-argument
        """List all tags on a directory."""
        self._validate_directory_id(resource_id)
        return self.tagger.list_tags_for_resource(resource_id).get("Tags")
Пример #13
0
class EventsBackend(BaseBackend):
    ACCOUNT_ID = re.compile(r"^(\d{1,12}|\*)$")
    STATEMENT_ID = re.compile(r"^[a-zA-Z0-9-_]{1,64}$")

    def __init__(self, region_name):
        self.rules = {}
        # This array tracks the order in which the rules have been added, since
        # 2.6 doesn't have OrderedDicts.
        self.rules_order = []
        self.next_tokens = {}
        self.region_name = region_name
        self.event_buses = {}
        self.event_sources = {}
        self.archives = {}
        self.replays = {}
        self.tagger = TaggingService()

        self._add_default_event_bus()

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def _add_default_event_bus(self):
        self.event_buses["default"] = EventBus(self.region_name, "default")

    def _get_rule_by_index(self, i):
        return self.rules.get(self.rules_order[i])

    def _gen_next_token(self, index):
        token = os.urandom(128).encode("base64")
        self.next_tokens[token] = index
        return token

    def _process_token_and_limits(self,
                                  array_len,
                                  next_token=None,
                                  limit=None):
        start_index = 0
        end_index = array_len
        new_next_token = None

        if next_token:
            start_index = self.next_tokens.pop(next_token, 0)

        if limit is not None:
            new_end_index = start_index + int(limit)
            if new_end_index < end_index:
                end_index = new_end_index
                new_next_token = self._gen_next_token(end_index)

        return start_index, end_index, new_next_token

    def _get_event_bus(self, name):
        event_bus_name = name.split("/")[-1]

        event_bus = self.event_buses.get(event_bus_name)
        if not event_bus:
            raise ResourceNotFoundException(
                "Event bus {} does not exist.".format(event_bus_name))

        return event_bus

    def _get_replay(self, name):
        replay = self.replays.get(name)
        if not replay:
            raise ResourceNotFoundException(
                "Replay {} does not exist.".format(name))

        return replay

    def delete_rule(self, name):
        self.rules_order.pop(self.rules_order.index(name))
        arn = self.rules.get(name).arn
        if self.tagger.has_tags(arn):
            self.tagger.delete_all_tags_for_resource(arn)
        return self.rules.pop(name) is not None

    def describe_rule(self, name):
        return self.rules.get(name)

    def disable_rule(self, name):
        if name in self.rules:
            self.rules[name].disable()
            return True

        return False

    def enable_rule(self, name):
        if name in self.rules:
            self.rules[name].enable()
            return True

        return False

    def list_rule_names_by_target(self,
                                  target_arn,
                                  next_token=None,
                                  limit=None):
        matching_rules = []
        return_obj = {}

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(self.rules), next_token, limit)

        for i in range(start_index, end_index):
            rule = self._get_rule_by_index(i)
            for target in rule.targets:
                if target["Arn"] == target_arn:
                    matching_rules.append(rule.name)

        return_obj["RuleNames"] = matching_rules
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def list_rules(self, prefix=None, next_token=None, limit=None):
        match_string = ".*"
        if prefix is not None:
            match_string = "^" + prefix + match_string

        match_regex = re.compile(match_string)

        matching_rules = []
        return_obj = {}

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(self.rules), next_token, limit)

        for i in range(start_index, end_index):
            rule = self._get_rule_by_index(i)
            if match_regex.match(rule.name):
                matching_rules.append(rule)

        return_obj["Rules"] = matching_rules
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def list_targets_by_rule(self, rule, next_token=None, limit=None):
        # We'll let a KeyError exception be thrown for response to handle if
        # rule doesn't exist.
        rule = self.rules[rule]

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(rule.targets), next_token, limit)

        returned_targets = []
        return_obj = {}

        for i in range(start_index, end_index):
            returned_targets.append(rule.targets[i])

        return_obj["Targets"] = returned_targets
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def update_rule(self, rule, **kwargs):
        rule.event_pattern = kwargs.get("EventPattern") or rule.event_pattern
        rule.schedule_exp = kwargs.get(
            "ScheduleExpression") or rule.schedule_exp
        rule.state = kwargs.get("State") or rule.state
        rule.description = kwargs.get("Description") or rule.description
        rule.role_arn = kwargs.get("RoleArn") or rule.role_arn
        rule.event_bus_name = kwargs.get("EventBusName") or rule.event_bus_name

    def put_rule(self, name, **kwargs):
        if kwargs.get("ScheduleExpression"
                      ) and kwargs.get("EventBusName") != "default":
            raise ValidationException(
                "ScheduleExpression is supported only on the default event bus."
            )

        if name in self.rules:
            self.update_rule(self.rules[name], **kwargs)
            new_rule = self.rules[name]
        else:
            new_rule = Rule(name, self.region_name, **kwargs)
            self.rules[new_rule.name] = new_rule
            self.rules_order.append(new_rule.name)
        return new_rule

    def put_targets(self, name, event_bus_name, targets):
        # super simple ARN check
        invalid_arn = next(
            (target["Arn"] for target in targets
             if not re.match(r"arn:[\d\w:\-/]*", target["Arn"])),
            None,
        )
        if invalid_arn:
            raise ValidationException(
                "Parameter {} is not valid. "
                "Reason: Provided Arn is not in correct format.".format(
                    invalid_arn))

        for target in targets:
            arn = target["Arn"]

            if (":sqs:" in arn and arn.endswith(".fifo")
                    and not target.get("SqsParameters")):
                raise ValidationException(
                    "Parameter(s) SqsParameters must be specified for target: {}."
                    .format(target["Id"]))

        rule = self.rules.get(name)

        if not rule:
            raise ResourceNotFoundException(
                "Rule {0} does not exist on EventBus {1}.".format(
                    name, event_bus_name))

        rule.put_targets(targets)

    def put_events(self, events):
        num_events = len(events)

        if num_events > 10:
            # the exact error text is longer, the Value list consists of all the put events
            raise ValidationException(
                "1 validation error detected: "
                "Value '[PutEventsRequestEntry]' at 'entries' failed to satisfy constraint: "
                "Member must have length less than or equal to 10")

        entries = []
        for event in events:
            if "Source" not in event:
                entries.append({
                    "ErrorCode":
                    "InvalidArgument",
                    "ErrorMessage":
                    "Parameter Source is not valid. Reason: Source is a required argument.",
                })
            elif "DetailType" not in event:
                entries.append({
                    "ErrorCode":
                    "InvalidArgument",
                    "ErrorMessage":
                    "Parameter DetailType is not valid. Reason: DetailType is a required argument.",
                })
            elif "Detail" not in event:
                entries.append({
                    "ErrorCode":
                    "InvalidArgument",
                    "ErrorMessage":
                    "Parameter Detail is not valid. Reason: Detail is a required argument.",
                })
            else:
                try:
                    json.loads(event["Detail"])
                except ValueError:  # json.JSONDecodeError exists since Python 3.5
                    entries.append({
                        "ErrorCode": "MalformedDetail",
                        "ErrorMessage": "Detail is malformed.",
                    })
                    continue

                event_id = str(uuid4())
                entries.append({"EventId": event_id})

                # if 'EventBusName' is not especially set, it will be sent to the default one
                event_bus_name = event.get("EventBusName", "default")

                for rule in self.rules.values():
                    rule.send_to_targets(
                        event_bus_name,
                        {
                            "version": "0",
                            "id": event_id,
                            "detail-type": event["DetailType"],
                            "source": event["Source"],
                            "account": ACCOUNT_ID,
                            "time": event.get("Time",
                                              unix_time(datetime.utcnow())),
                            "region": self.region_name,
                            "resources": event.get("Resources", []),
                            "detail": json.loads(event["Detail"]),
                        },
                    )

        return entries

    def remove_targets(self, name, event_bus_name, ids):
        rule = self.rules.get(name)

        if not rule:
            raise ResourceNotFoundException(
                "Rule {0} does not exist on EventBus {1}.".format(
                    name, event_bus_name))

        rule.remove_targets(ids)

    def test_event_pattern(self):
        raise NotImplementedError()

    def put_permission(self, event_bus_name, action, principal, statement_id):
        if not event_bus_name:
            event_bus_name = "default"

        event_bus = self.describe_event_bus(event_bus_name)

        if action is None or action != "events:PutEvents":
            raise JsonRESTError(
                "ValidationException",
                "Provided value in parameter 'action' is not supported.",
            )

        if principal is None or self.ACCOUNT_ID.match(principal) is None:
            raise JsonRESTError("InvalidParameterValue",
                                r"Principal must match ^(\d{1,12}|\*)$")

        if statement_id is None or self.STATEMENT_ID.match(
                statement_id) is None:
            raise JsonRESTError(
                "InvalidParameterValue",
                r"StatementId must match ^[a-zA-Z0-9-_]{1,64}$")

        event_bus._permissions[statement_id] = {
            "Action": action,
            "Principal": principal,
        }

    def remove_permission(self, event_bus_name, statement_id):
        if not event_bus_name:
            event_bus_name = "default"

        event_bus = self.describe_event_bus(event_bus_name)

        if not len(event_bus._permissions):
            raise JsonRESTError("ResourceNotFoundException",
                                "EventBus does not have a policy.")

        if not event_bus._permissions.pop(statement_id, None):
            raise JsonRESTError(
                "ResourceNotFoundException",
                "Statement with the provided id does not exist.",
            )

    def describe_event_bus(self, name):
        if not name:
            name = "default"

        event_bus = self._get_event_bus(name)

        return event_bus

    def create_event_bus(self, name, event_source_name=None):
        if name in self.event_buses:
            raise JsonRESTError(
                "ResourceAlreadyExistsException",
                "Event bus {} already exists.".format(name),
            )

        if not event_source_name and "/" in name:
            raise JsonRESTError("ValidationException",
                                "Event bus name must not contain '/'.")

        if event_source_name and event_source_name not in self.event_sources:
            raise JsonRESTError(
                "ResourceNotFoundException",
                "Event source {} does not exist.".format(event_source_name),
            )

        self.event_buses[name] = EventBus(self.region_name, name)

        return self.event_buses[name]

    def list_event_buses(self, name_prefix):
        if name_prefix:
            return [
                event_bus for event_bus in self.event_buses.values()
                if event_bus.name.startswith(name_prefix)
            ]

        return list(self.event_buses.values())

    def delete_event_bus(self, name):
        if name == "default":
            raise JsonRESTError("ValidationException",
                                "Cannot delete event bus default.")
        self.event_buses.pop(name, None)

    def list_tags_for_resource(self, arn):
        name = arn.split("/")[-1]
        if name in self.rules:
            return self.tagger.list_tags_for_resource(self.rules[name].arn)
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def tag_resource(self, arn, tags):
        name = arn.split("/")[-1]
        if name in self.rules:
            self.tagger.tag_resource(self.rules[name].arn, tags)
            return {}
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def untag_resource(self, arn, tag_names):
        name = arn.split("/")[-1]
        if name in self.rules:
            self.tagger.untag_resource_using_names(self.rules[name].arn,
                                                   tag_names)
            return {}
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def create_archive(self, name, source_arn, description, event_pattern,
                       retention):
        if len(name) > 48:
            raise ValidationException(
                " 1 validation error detected: "
                "Value '{}' at 'archiveName' failed to satisfy constraint: "
                "Member must have length less than or equal to 48".format(
                    name))

        event_bus = self._get_event_bus(source_arn)

        if name in self.archives:
            raise ResourceAlreadyExistsException(
                "Archive {} already exists.".format(name))

        archive = Archive(self.region_name, name, source_arn, description,
                          event_pattern, retention)

        rule_event_pattern = json.loads(event_pattern or "{}")
        rule_event_pattern["replay-name"] = [{"exists": False}]

        rule = self.put_rule(
            "Events-Archive-{}".format(name), **{
                "EventPattern": json.dumps(rule_event_pattern),
                "EventBusName": event_bus.name,
                "ManagedBy": "prod.vhs.events.aws.internal",
            })
        self.put_targets(
            rule.name,
            rule.event_bus_name,
            [{
                "Id": rule.name,
                "Arn": "arn:aws:events:{}:::".format(self.region_name),
                "InputTransformer": {
                    "InputPathsMap": {},
                    "InputTemplate":
                    json.dumps({
                        "archive-arn":
                        "{0}:{1}".format(archive.arn, archive.uuid),
                        "event":
                        "<aws.events.event.json>",
                        "ingestion-time":
                        "<aws.events.event.ingestion-time>",
                    }),
                },
            }],
        )

        self.archives[name] = archive

        return archive

    def describe_archive(self, name):
        archive = self.archives.get(name)

        if not archive:
            raise ResourceNotFoundException(
                "Archive {} does not exist.".format(name))

        return archive.describe()

    def list_archives(self, name_prefix, source_arn, state):
        if [name_prefix, source_arn, state].count(None) < 2:
            raise ValidationException(
                "At most one filter is allowed for ListArchives. "
                "Use either : State, EventSourceArn, or NamePrefix.")

        if state and state not in Archive.VALID_STATES:
            raise ValidationException(
                "1 validation error detected: "
                "Value '{0}' at 'state' failed to satisfy constraint: "
                "Member must satisfy enum value set: "
                "[{1}]".format(state, ", ".join(Archive.VALID_STATES)))

        if [name_prefix, source_arn, state].count(None) == 3:
            return [
                archive.describe_short() for archive in self.archives.values()
            ]

        result = []

        for archive in self.archives.values():
            if name_prefix and archive.name.startswith(name_prefix):
                result.append(archive.describe_short())
            elif source_arn and archive.source_arn == source_arn:
                result.append(archive.describe_short())
            elif state and archive.state == state:
                result.append(archive.describe_short())

        return result

    def update_archive(self, name, description, event_pattern, retention):
        archive = self.archives.get(name)

        if not archive:
            raise ResourceNotFoundException(
                "Archive {} does not exist.".format(name))

        archive.update(description, event_pattern, retention)

        return {
            "ArchiveArn": archive.arn,
            "CreationTime": archive.creation_time,
            "State": archive.state,
        }

    def delete_archive(self, name):
        archive = self.archives.get(name)

        if not archive:
            raise ResourceNotFoundException(
                "Archive {} does not exist.".format(name))

        archive.delete(self.region_name)

    def start_replay(self, name, description, source_arn, start_time, end_time,
                     destination):
        event_bus_arn = destination["Arn"]
        event_bus_arn_pattern = r"^arn:aws:events:[a-zA-Z0-9-]+:\d{12}:event-bus/"
        if not re.match(event_bus_arn_pattern, event_bus_arn):
            raise ValidationException(
                "Parameter Destination.Arn is not valid. "
                "Reason: Must contain an event bus ARN.")

        self._get_event_bus(event_bus_arn)

        archive_name = source_arn.split("/")[-1]
        archive = self.archives.get(archive_name)
        if not archive:
            raise ValidationException(
                "Parameter EventSourceArn is not valid. "
                "Reason: Archive {} does not exist.".format(archive_name))

        if event_bus_arn != archive.source_arn:
            raise ValidationException(
                "Parameter Destination.Arn is not valid. "
                "Reason: Cross event bus replay is not permitted.")

        if start_time > end_time:
            raise ValidationException(
                "Parameter EventEndTime is not valid. "
                "Reason: EventStartTime must be before EventEndTime.")

        if name in self.replays:
            raise ResourceAlreadyExistsException(
                "Replay {} already exists.".format(name))

        replay = Replay(
            self.region_name,
            name,
            description,
            source_arn,
            start_time,
            end_time,
            destination,
        )

        self.replays[name] = replay

        replay.replay_events(archive)

        return {
            "ReplayArn": replay.arn,
            "ReplayStartTime": replay.start_time,
            "State": ReplayState.STARTING.
            value,  # the replay will be done before returning the response
        }

    def describe_replay(self, name):
        replay = self._get_replay(name)

        return replay.describe()

    def list_replays(self, name_prefix, source_arn, state):
        if [name_prefix, source_arn, state].count(None) < 2:
            raise ValidationException(
                "At most one filter is allowed for ListReplays. "
                "Use either : State, EventSourceArn, or NamePrefix.")

        valid_states = sorted([item.value for item in ReplayState])
        if state and state not in valid_states:
            raise ValidationException(
                "1 validation error detected: "
                "Value '{0}' at 'state' failed to satisfy constraint: "
                "Member must satisfy enum value set: "
                "[{1}]".format(state, ", ".join(valid_states)))

        if [name_prefix, source_arn, state].count(None) == 3:
            return [
                replay.describe_short() for replay in self.replays.values()
            ]

        result = []

        for replay in self.replays.values():
            if name_prefix and replay.name.startswith(name_prefix):
                result.append(replay.describe_short())
            elif source_arn and replay.source_arn == source_arn:
                result.append(replay.describe_short())
            elif state and replay.state == state:
                result.append(replay.describe_short())

        return result

    def cancel_replay(self, name):
        replay = self._get_replay(name)

        # replays in the state 'COMPLETED' can't be canceled,
        # but the implementation is done synchronously,
        # so they are done right after the start
        if replay.state not in [
                ReplayState.STARTING,
                ReplayState.RUNNING,
                ReplayState.COMPLETED,
        ]:
            raise IllegalStatusException(
                "Replay {} is not in a valid state for this operation.".format(
                    name))

        replay.state = ReplayState.CANCELLED

        return {"ReplayArn": replay.arn, "State": ReplayState.CANCELLING.value}
Пример #14
0
class Route53ResolverBackend(BaseBackend):
    """Implementation of Route53Resolver APIs."""

    def __init__(self, region_name, account_id):
        super().__init__(region_name, account_id)
        self.resolver_endpoints = {}  # Key is self-generated ID (endpoint_id)
        self.resolver_rules = {}  # Key is self-generated ID (rule_id)
        self.resolver_rule_associations = {}  # Key is resolver_rule_association_id)
        self.tagger = TaggingService()

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """List of dicts representing default VPC endpoints for this service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "route53resolver"
        )

    def associate_resolver_rule(self, region, resolver_rule_id, name, vpc_id):
        validate_args(
            [("resolverRuleId", resolver_rule_id), ("name", name), ("vPCId", vpc_id)]
        )

        associations = [
            x for x in self.resolver_rule_associations.values() if x.region == region
        ]
        if len(associations) > ResolverRuleAssociation.MAX_RULE_ASSOCIATIONS_PER_REGION:
            # This error message was not verified to be the same for AWS.
            raise LimitExceededException(
                f"Account '{get_account_id()}' has exceeded 'max-rule-association'"
            )

        if resolver_rule_id not in self.resolver_rules:
            raise ResourceNotFoundException(
                f"Resolver rule with ID '{resolver_rule_id}' does not exist."
            )

        vpcs = ec2_backends[region].describe_vpcs()
        if vpc_id not in [x.id for x in vpcs]:
            raise InvalidParameterException(f"The vpc ID '{vpc_id}' does not exist")

        # Can't duplicate resolver rule, vpc id associations.
        for association in self.resolver_rule_associations.values():
            if (
                resolver_rule_id == association.resolver_rule_id
                and vpc_id == association.vpc_id
            ):
                raise InvalidRequestException(
                    f"Cannot associate rules with same domain name with same "
                    f"VPC. Conflict with resolver rule '{resolver_rule_id}'"
                )

        rule_association_id = f"rslvr-rrassoc-{get_random_hex(17)}"
        rule_association = ResolverRuleAssociation(
            region, rule_association_id, resolver_rule_id, vpc_id, name
        )
        self.resolver_rule_associations[rule_association_id] = rule_association
        return rule_association

    @staticmethod
    def _verify_subnet_ips(region, ip_addresses, initial=True):
        """Perform additional checks on the IPAddresses.

        NOTE: This does not include IPv6 addresses.
        """
        # only initial endpoint creation requires atleast two ip addresses
        if initial:
            if len(ip_addresses) < 2:
                raise InvalidRequestException(
                    "Resolver endpoint needs to have at least 2 IP addresses"
                )

        subnets = defaultdict(set)
        for subnet_id, ip_addr in [(x["SubnetId"], x["Ip"]) for x in ip_addresses]:
            try:
                subnet_info = ec2_backends[region].get_all_subnets(
                    subnet_ids=[subnet_id]
                )[0]
            except InvalidSubnetIdError as exc:
                raise InvalidParameterException(
                    f"The subnet ID '{subnet_id}' does not exist"
                ) from exc

            # IP in IPv4 CIDR range and not reserved?
            if ip_address(ip_addr) in subnet_info.reserved_ips or ip_address(
                ip_addr
            ) not in ip_network(subnet_info.cidr_block):
                raise InvalidRequestException(
                    f"IP address '{ip_addr}' is either not in subnet "
                    f"'{subnet_id}' CIDR range or is reserved"
                )

            if ip_addr in subnets[subnet_id]:
                raise ResourceExistsException(
                    f"The IP address '{ip_addr}' in subnet '{subnet_id}' is already in use"
                )
            subnets[subnet_id].add(ip_addr)

    @staticmethod
    def _verify_security_group_ids(region, security_group_ids):
        """Perform additional checks on the security groups."""
        if len(security_group_ids) > 10:
            raise InvalidParameterException("Maximum of 10 security groups are allowed")

        for group_id in security_group_ids:
            if not group_id.startswith("sg-"):
                raise InvalidParameterException(
                    f"Malformed security group ID: Invalid id: '{group_id}' "
                    f"(expecting 'sg-...')"
                )
            try:
                ec2_backends[region].describe_security_groups(group_ids=[group_id])
            except InvalidSecurityGroupNotFoundError as exc:
                raise ResourceNotFoundException(
                    f"The security group '{group_id}' does not exist"
                ) from exc

    def create_resolver_endpoint(
        self,
        region,
        creator_request_id,
        name,
        security_group_ids,
        direction,
        ip_addresses,
        tags,
    ):  # pylint: disable=too-many-arguments
        """
        Return description for a newly created resolver endpoint.

        NOTE:  IPv6 IPs are currently not being filtered when
        calculating the create_resolver_endpoint() IpAddresses.
        """
        validate_args(
            [
                ("creatorRequestId", creator_request_id),
                ("direction", direction),
                ("ipAddresses", ip_addresses),
                ("name", name),
                ("securityGroupIds", security_group_ids),
                ("ipAddresses.subnetId", ip_addresses),
            ]
        )
        errmsg = self.tagger.validate_tags(
            tags or [], limit=ResolverEndpoint.MAX_TAGS_PER_RESOLVER_ENDPOINT
        )
        if errmsg:
            raise TagValidationException(errmsg)

        endpoints = [x for x in self.resolver_endpoints.values() if x.region == region]
        if len(endpoints) > ResolverEndpoint.MAX_ENDPOINTS_PER_REGION:
            raise LimitExceededException(
                f"Account '{get_account_id()}' has exceeded 'max-endpoints'"
            )

        for x in ip_addresses:
            if not x.get("Ip"):
                subnet_info = ec2_backends[region].get_all_subnets(
                    subnet_ids=[x["SubnetId"]]
                )[0]
                x["Ip"] = subnet_info.get_available_subnet_ip(self)

        self._verify_subnet_ips(region, ip_addresses)
        self._verify_security_group_ids(region, security_group_ids)
        if creator_request_id in [
            x.creator_request_id for x in self.resolver_endpoints.values()
        ]:
            raise ResourceExistsException(
                f"Resolver endpoint with creator request ID "
                f"'{creator_request_id}' already exists"
            )

        endpoint_id = (
            f"rslvr-{'in' if direction == 'INBOUND' else 'out'}-{get_random_hex(17)}"
        )
        resolver_endpoint = ResolverEndpoint(
            region,
            endpoint_id,
            creator_request_id,
            security_group_ids,
            direction,
            ip_addresses,
            name,
        )

        self.resolver_endpoints[endpoint_id] = resolver_endpoint
        self.tagger.tag_resource(resolver_endpoint.arn, tags or [])
        return resolver_endpoint

    def create_resolver_rule(
        self,
        region,
        creator_request_id,
        name,
        rule_type,
        domain_name,
        target_ips,
        resolver_endpoint_id,
        tags,
    ):  # pylint: disable=too-many-arguments
        """Return description for a newly created resolver rule."""
        validate_args(
            [
                ("creatorRequestId", creator_request_id),
                ("ruleType", rule_type),
                ("domainName", domain_name),
                ("name", name),
                *[("targetIps.port", x) for x in target_ips],
                ("resolverEndpointId", resolver_endpoint_id),
            ]
        )
        errmsg = self.tagger.validate_tags(
            tags or [], limit=ResolverRule.MAX_TAGS_PER_RESOLVER_RULE
        )
        if errmsg:
            raise TagValidationException(errmsg)

        rules = [x for x in self.resolver_rules.values() if x.region == region]
        if len(rules) > ResolverRule.MAX_RULES_PER_REGION:
            # Did not verify that this is the actual error message.
            raise LimitExceededException(
                f"Account '{get_account_id()}' has exceeded 'max-rules'"
            )

        # Per the AWS documentation and as seen with the AWS console, target
        # ips are only relevant when the value of Rule is FORWARD.  However,
        # boto3 ignores this condition and so shall we.

        for ip_addr in [x["Ip"] for x in target_ips]:
            try:
                # boto3 fails with an InternalServiceException if IPv6
                # addresses are used, which isn't helpful.
                if not isinstance(ip_address(ip_addr), IPv4Address):
                    raise InvalidParameterException(
                        f"Only IPv4 addresses may be used: '{ip_addr}'"
                    )
            except ValueError as exc:
                raise InvalidParameterException(
                    f"Invalid IP address: '{ip_addr}'"
                ) from exc

        # The boto3 documentation indicates that ResolverEndpoint is
        # optional, as does the AWS documention.  But if resolver_endpoint_id
        # is set to None or an empty string, it results in boto3 raising
        # a ParamValidationError either regarding the type or len of string.
        if resolver_endpoint_id:
            if resolver_endpoint_id not in [
                x.id for x in self.resolver_endpoints.values()
            ]:
                raise ResourceNotFoundException(
                    f"Resolver endpoint with ID '{resolver_endpoint_id}' does not exist."
                )

            if rule_type == "SYSTEM":
                raise InvalidRequestException(
                    "Cannot specify resolver endpoint ID and target IP "
                    "for SYSTEM type resolver rule"
                )

        if creator_request_id in [
            x.creator_request_id for x in self.resolver_rules.values()
        ]:
            raise ResourceExistsException(
                f"Resolver rule with creator request ID "
                f"'{creator_request_id}' already exists"
            )

        rule_id = f"rslvr-rr-{get_random_hex(17)}"
        resolver_rule = ResolverRule(
            region,
            rule_id,
            creator_request_id,
            rule_type,
            domain_name,
            target_ips,
            resolver_endpoint_id,
            name,
        )

        self.resolver_rules[rule_id] = resolver_rule
        self.tagger.tag_resource(resolver_rule.arn, tags or [])
        return resolver_rule

    def _validate_resolver_endpoint_id(self, resolver_endpoint_id):
        """Raise an exception if the id is invalid or unknown."""
        validate_args([("resolverEndpointId", resolver_endpoint_id)])
        if resolver_endpoint_id not in self.resolver_endpoints:
            raise ResourceNotFoundException(
                f"Resolver endpoint with ID '{resolver_endpoint_id}' does not exist"
            )

    def delete_resolver_endpoint(self, resolver_endpoint_id):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)

        # Can't delete an endpoint if there are rules associated with it.
        rules = [
            x.id
            for x in self.resolver_rules.values()
            if x.resolver_endpoint_id == resolver_endpoint_id
        ]
        if rules:
            raise InvalidRequestException(
                f"Cannot delete resolver endpoint unless its related resolver "
                f"rules are deleted.  The following rules still exist for "
                f"this resolver endpoint:  {','.join(rules)}"
            )

        self.tagger.delete_all_tags_for_resource(resolver_endpoint_id)
        resolver_endpoint = self.resolver_endpoints.pop(resolver_endpoint_id)
        resolver_endpoint.delete_eni()
        resolver_endpoint.status = "DELETING"
        resolver_endpoint.status_message = resolver_endpoint.status_message.replace(
            "Successfully created", "Deleting"
        )
        return resolver_endpoint

    def _validate_resolver_rule_id(self, resolver_rule_id):
        """Raise an exception if the id is invalid or unknown."""
        validate_args([("resolverRuleId", resolver_rule_id)])
        if resolver_rule_id not in self.resolver_rules:
            raise ResourceNotFoundException(
                f"Resolver rule with ID '{resolver_rule_id}' does not exist"
            )

    def delete_resolver_rule(self, resolver_rule_id):
        self._validate_resolver_rule_id(resolver_rule_id)

        # Can't delete an rule unless VPC's are disassociated.
        associations = [
            x.id
            for x in self.resolver_rule_associations.values()
            if x.resolver_rule_id == resolver_rule_id
        ]
        if associations:
            raise ResourceInUseException(
                "Please disassociate this resolver rule from VPC first "
                "before deleting"
            )

        self.tagger.delete_all_tags_for_resource(resolver_rule_id)
        resolver_rule = self.resolver_rules.pop(resolver_rule_id)
        resolver_rule.status = "DELETING"
        resolver_rule.status_message = resolver_rule.status_message.replace(
            "Successfully created", "Deleting"
        )
        return resolver_rule

    def disassociate_resolver_rule(self, resolver_rule_id, vpc_id):
        validate_args([("resolverRuleId", resolver_rule_id), ("vPCId", vpc_id)])

        # Non-existent rule or vpc ids?
        if resolver_rule_id not in self.resolver_rules:
            raise ResourceNotFoundException(
                f"Resolver rule with ID '{resolver_rule_id}' does not exist"
            )

        # Find the matching association for this rule and vpc.
        rule_association_id = None
        for association in self.resolver_rule_associations.values():
            if (
                resolver_rule_id == association.resolver_rule_id
                and vpc_id == association.vpc_id
            ):
                rule_association_id = association.id
                break
        else:
            raise ResourceNotFoundException(
                f"Resolver Rule Association between Resolver Rule "
                f"'{resolver_rule_id}' and VPC '{vpc_id}' does not exist"
            )

        rule_association = self.resolver_rule_associations.pop(rule_association_id)
        rule_association.status = "DELETING"
        rule_association.status_message = "Deleting Association"
        return rule_association

    def get_resolver_endpoint(self, resolver_endpoint_id):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        return self.resolver_endpoints[resolver_endpoint_id]

    def get_resolver_rule(self, resolver_rule_id):
        """Return info for specified resolver rule."""
        self._validate_resolver_rule_id(resolver_rule_id)
        return self.resolver_rules[resolver_rule_id]

    def get_resolver_rule_association(self, resolver_rule_association_id):
        validate_args([("resolverRuleAssociationId", resolver_rule_association_id)])
        if resolver_rule_association_id not in self.resolver_rule_associations:
            raise ResourceNotFoundException(
                f"ResolverRuleAssociation '{resolver_rule_association_id}' does not Exist"
            )
        return self.resolver_rule_associations[resolver_rule_association_id]

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_endpoint_ip_addresses(self, resolver_endpoint_id):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        endpoint = self.resolver_endpoints[resolver_endpoint_id]
        return endpoint.ip_descriptions()

    @staticmethod
    def _add_field_name_to_filter(filters):
        """Convert both styles of filter names to lowercase snake format.

        "IP_ADDRESS_COUNT" or "IpAddressCount" will become "ip_address_count".
        However, "HostVPCId" doesn't fit the pattern, so that's treated
        special.
        """
        for rr_filter in filters:
            filter_name = rr_filter["Name"]
            if "_" not in filter_name:
                if "Vpc" in filter_name:
                    filter_name = "WRONG"
                elif filter_name == "HostVPCId":
                    filter_name = "host_vpc_id"
                elif filter_name == "VPCId":
                    filter_name = "vpc_id"
                elif filter_name in ["Type", "TYPE"]:
                    filter_name = "rule_type"
                elif not filter_name.isupper():
                    filter_name = CAMEL_TO_SNAKE_PATTERN.sub("_", filter_name)
            rr_filter["Field"] = filter_name.lower()

    @staticmethod
    def _validate_filters(filters, allowed_filter_names):
        """Raise exception if filter names are not as expected."""
        for rr_filter in filters:
            if rr_filter["Field"] not in allowed_filter_names:
                raise InvalidParameterException(
                    f"The filter '{rr_filter['Name']}' is invalid"
                )
            if "Values" not in rr_filter:
                raise InvalidParameterException(
                    f"No values specified for filter {rr_filter['Name']}"
                )

    @staticmethod
    def _matches_all_filters(entity, filters):
        """Return True if this entity has fields matching all the filters."""
        for rr_filter in filters:
            field_value = getattr(entity, rr_filter["Field"])

            if isinstance(field_value, list):
                if not set(field_value).intersection(rr_filter["Values"]):
                    return False
            elif isinstance(field_value, int):
                if str(field_value) not in rr_filter["Values"]:
                    return False
            elif field_value not in rr_filter["Values"]:
                return False

        return True

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_endpoints(self, filters):
        if not filters:
            filters = []

        self._add_field_name_to_filter(filters)
        self._validate_filters(filters, ResolverEndpoint.FILTER_NAMES)

        endpoints = []
        for endpoint in sorted(self.resolver_endpoints.values(), key=lambda x: x.name):
            if self._matches_all_filters(endpoint, filters):
                endpoints.append(endpoint)
        return endpoints

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_rules(self, filters):
        if not filters:
            filters = []

        self._add_field_name_to_filter(filters)
        self._validate_filters(filters, ResolverRule.FILTER_NAMES)

        rules = []
        for rule in sorted(self.resolver_rules.values(), key=lambda x: x.name):
            if self._matches_all_filters(rule, filters):
                rules.append(rule)
        return rules

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_resolver_rule_associations(self, filters):
        if not filters:
            filters = []

        self._add_field_name_to_filter(filters)
        self._validate_filters(filters, ResolverRuleAssociation.FILTER_NAMES)

        rules = []
        for rule in sorted(
            self.resolver_rule_associations.values(), key=lambda x: x.name
        ):
            if self._matches_all_filters(rule, filters):
                rules.append(rule)
        return rules

    def _matched_arn(self, resource_arn):
        """Given ARN, raise exception if there is no corresponding resource."""
        for resolver_endpoint in self.resolver_endpoints.values():
            if resolver_endpoint.arn == resource_arn:
                return
        for resolver_rule in self.resolver_rules.values():
            if resolver_rule.arn == resource_arn:
                return
        raise ResourceNotFoundException(
            f"Resolver endpoint with ID '{resource_arn}' does not exist"
        )

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags_for_resource(self, resource_arn):
        self._matched_arn(resource_arn)
        return self.tagger.list_tags_for_resource(resource_arn).get("Tags")

    def tag_resource(self, resource_arn, tags):
        self._matched_arn(resource_arn)
        errmsg = self.tagger.validate_tags(
            tags, limit=ResolverEndpoint.MAX_TAGS_PER_RESOLVER_ENDPOINT
        )
        if errmsg:
            raise TagValidationException(errmsg)
        self.tagger.tag_resource(resource_arn, tags)

    def untag_resource(self, resource_arn, tag_keys):
        self._matched_arn(resource_arn)
        self.tagger.untag_resource_using_names(resource_arn, tag_keys)

    def update_resolver_endpoint(self, resolver_endpoint_id, name):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        validate_args([("name", name)])
        resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]
        resolver_endpoint.update_name(name)
        return resolver_endpoint

    def associate_resolver_endpoint_ip_address(
        self, region, resolver_endpoint_id, ip_address
    ):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]

        if not ip_address.get("Ip"):
            subnet_info = ec2_backends[region].get_all_subnets(
                subnet_ids=[ip_address.get("SubnetId")]
            )[0]
            ip_address["Ip"] = subnet_info.get_available_subnet_ip(self)
        self._verify_subnet_ips(region, [ip_address], False)

        resolver_endpoint.associate_ip_address(ip_address)
        return resolver_endpoint

    def disassociate_resolver_endpoint_ip_address(
        self, resolver_endpoint_id, ip_address
    ):
        self._validate_resolver_endpoint_id(resolver_endpoint_id)
        resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]

        if not (ip_address.get("Ip") or ip_address.get("IpId")):
            raise InvalidRequestException(
                "[RSLVR-00503] Need to specify either the IP ID or both subnet and IP address in order to remove IP address."
            )

        resolver_endpoint.disassociate_ip_address(ip_address)
        return resolver_endpoint
Пример #15
0
class DAXBackend(BaseBackend):
    def __init__(self, region_name):
        self.region_name = region_name
        self._clusters = dict()
        self._tagger = TaggingService()

    @property
    def clusters(self):
        self._clusters = {
            name: cluster
            for name, cluster in self._clusters.items()
            if cluster.status != "deleted"
        }
        return self._clusters

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def create_cluster(
        self,
        cluster_name,
        node_type,
        description,
        replication_factor,
        availability_zones,
        subnet_group_name,
        security_group_ids,
        preferred_maintenance_window,
        notification_topic_arn,
        iam_role_arn,
        parameter_group_name,
        tags,
        sse_specification,
        cluster_endpoint_encryption_type,
    ):
        """
        The following parameters are not yet processed:
        AvailabilityZones, SubnetGroupNames, SecurityGroups, PreferredMaintenanceWindow, NotificationTopicArn, ParameterGroupName, ClusterEndpointEncryptionType
        """
        cluster = DaxCluster(
            region=self.region_name,
            name=cluster_name,
            description=description,
            node_type=node_type,
            replication_factor=replication_factor,
            iam_role_arn=iam_role_arn,
            sse_specification=sse_specification,
        )
        self.clusters[cluster_name] = cluster
        self._tagger.tag_resource(cluster.arn, tags)
        return cluster

    def delete_cluster(self, cluster_name):
        if cluster_name not in self.clusters:
            raise ClusterNotFoundFault()
        self.clusters[cluster_name].delete()
        return self.clusters[cluster_name]

    @paginate(PAGINATION_MODEL)
    def describe_clusters(self, cluster_names):
        clusters = self.clusters
        if not cluster_names:
            cluster_names = clusters.keys()

        for name in cluster_names:
            if name in self.clusters:
                self.clusters[name].advance()

        # Clusters may have been deleted while advancing the states
        clusters = self.clusters
        for name in cluster_names:
            if name not in self.clusters:
                raise ClusterNotFoundFault(name)
        return [
            cluster for name, cluster in clusters.items()
            if name in cluster_names
        ]

    def list_tags(self, resource_name):
        """
        Pagination is not yet implemented
        """
        # resource_name can be the name, or the full ARN
        name = resource_name.split("/")[-1]
        if name not in self.clusters:
            raise ClusterNotFoundFault()
        return self._tagger.list_tags_for_resource(self.clusters[name].arn)

    def increase_replication_factor(self, cluster_name, new_replication_factor,
                                    availability_zones):
        if cluster_name not in self.clusters:
            raise ClusterNotFoundFault()
        self.clusters[cluster_name].increase_replication_factor(
            new_replication_factor)
        return self.clusters[cluster_name]

    def decrease_replication_factor(
        self,
        cluster_name,
        new_replication_factor,
        availability_zones,
        node_ids_to_remove,
    ):
        """
        The AvailabilityZones-parameter is not yet implemented
        """
        if cluster_name not in self.clusters:
            raise ClusterNotFoundFault()
        self.clusters[cluster_name].decrease_replication_factor(
            new_replication_factor, node_ids_to_remove)
        return self.clusters[cluster_name]
Пример #16
0
class KmsBackend(BaseBackend):
    def __init__(self, region_name, account_id=None):
        super().__init__(region_name=region_name, account_id=account_id)
        self.keys = {}
        self.key_to_aliases = defaultdict(set)
        self.tagger = TaggingService(key_name="TagKey", value_name="TagValue")

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "kms")

    def _generate_default_keys(self, alias_name):
        """Creates default kms keys"""
        if alias_name in RESERVED_ALIASES:
            key = self.create_key(
                None,
                "ENCRYPT_DECRYPT",
                "SYMMETRIC_DEFAULT",
                "Default key",
                None,
                self.region_name,
            )
            self.add_alias(key.id, alias_name)
            return key.id

    def create_key(self,
                   policy,
                   key_usage,
                   key_spec,
                   description,
                   tags,
                   region,
                   multi_region=False):
        key = Key(policy, key_usage, key_spec, description, region,
                  multi_region)
        self.keys[key.id] = key
        if tags is not None and len(tags) > 0:
            self.tag_resource(key.id, tags)
        return key

    # https://docs.aws.amazon.com/kms/latest/developerguide/multi-region-keys-overview.html#mrk-sync-properties
    # In AWS replicas of a key only share some properties with the original key. Some of those properties get updated
    # in all replicas automatically if those properties change in the original key. Also, such properties can not be
    # changed for replicas directly.
    #
    # In our implementation with just create a copy of all the properties once without any protection from change,
    # as the exact implementation is currently infeasible.
    def replicate_key(self, key_id, replica_region):
        # Using copy() instead of deepcopy(), as the latter results in exception:
        #    TypeError: cannot pickle '_cffi_backend.FFI' object
        # Since we only update top level properties, copy() should suffice.
        replica_key = copy(self.keys[key_id])
        replica_key.region = replica_region
        to_region_backend = kms_backends[replica_region]
        to_region_backend.keys[replica_key.id] = replica_key

    def update_key_description(self, key_id, description):
        key = self.keys[self.get_key_id(key_id)]
        key.description = description

    def delete_key(self, key_id):
        if key_id in self.keys:
            if key_id in self.key_to_aliases:
                self.key_to_aliases.pop(key_id)
            self.tagger.delete_all_tags_for_resource(key_id)

            return self.keys.pop(key_id)

    def describe_key(self, key_id) -> Key:
        # allow the different methods (alias, ARN :key/, keyId, ARN alias) to
        # describe key not just KeyId
        key_id = self.get_key_id(key_id)
        if r"alias/" in str(key_id).lower():
            key_id = self.get_key_id_from_alias(key_id)
        return self.keys[self.get_key_id(key_id)]

    def list_keys(self):
        return self.keys.values()

    @staticmethod
    def get_key_id(key_id):
        # Allow use of ARN as well as pure KeyId
        if key_id.startswith("arn:") and ":key/" in key_id:
            return key_id.split(":key/")[1]

        return key_id

    @staticmethod
    def get_alias_name(alias_name):
        # Allow use of ARN as well as alias name
        if alias_name.startswith("arn:") and ":alias/" in alias_name:
            return "alias/" + alias_name.split(":alias/")[1]

        return alias_name

    def any_id_to_key_id(self, key_id):
        """Go from any valid key ID to the raw key ID.

        Acceptable inputs:
        - raw key ID
        - key ARN
        - alias name
        - alias ARN
        """
        key_id = self.get_alias_name(key_id)
        key_id = self.get_key_id(key_id)
        if key_id.startswith("alias/"):
            key_id = self.get_key_id_from_alias(key_id)
        return key_id

    def alias_exists(self, alias_name):
        for aliases in self.key_to_aliases.values():
            if alias_name in aliases:
                return True

        return False

    def add_alias(self, target_key_id, alias_name):
        self.key_to_aliases[target_key_id].add(alias_name)

    def delete_alias(self, alias_name):
        """Delete the alias."""
        for aliases in self.key_to_aliases.values():
            if alias_name in aliases:
                aliases.remove(alias_name)

    def get_all_aliases(self):
        return self.key_to_aliases

    def get_key_id_from_alias(self, alias_name):
        for key_id, aliases in dict(self.key_to_aliases).items():
            if alias_name in ",".join(aliases):
                return key_id
        if alias_name in RESERVED_ALIASES:
            key_id = self._generate_default_keys(alias_name)
            return key_id
        return None

    def enable_key_rotation(self, key_id):
        self.keys[self.get_key_id(key_id)].key_rotation_status = True

    def disable_key_rotation(self, key_id):
        self.keys[self.get_key_id(key_id)].key_rotation_status = False

    def get_key_rotation_status(self, key_id):
        return self.keys[self.get_key_id(key_id)].key_rotation_status

    def put_key_policy(self, key_id, policy):
        self.keys[self.get_key_id(key_id)].policy = policy

    def get_key_policy(self, key_id):
        return self.keys[self.get_key_id(key_id)].policy

    def disable_key(self, key_id):
        self.keys[key_id].enabled = False
        self.keys[key_id].key_state = "Disabled"

    def enable_key(self, key_id):
        self.keys[key_id].enabled = True
        self.keys[key_id].key_state = "Enabled"

    def cancel_key_deletion(self, key_id):
        self.keys[key_id].key_state = "Disabled"
        self.keys[key_id].deletion_date = None

    def schedule_key_deletion(self, key_id, pending_window_in_days):
        if 7 <= pending_window_in_days <= 30:
            self.keys[key_id].enabled = False
            self.keys[key_id].key_state = "PendingDeletion"
            self.keys[key_id].deletion_date = datetime.now() + timedelta(
                days=pending_window_in_days)
            return unix_time(self.keys[key_id].deletion_date)

    def encrypt(self, key_id, plaintext, encryption_context):
        key_id = self.any_id_to_key_id(key_id)

        ciphertext_blob = encrypt(
            master_keys=self.keys,
            key_id=key_id,
            plaintext=plaintext,
            encryption_context=encryption_context,
        )
        arn = self.keys[key_id].arn
        return ciphertext_blob, arn

    def decrypt(self, ciphertext_blob, encryption_context):
        plaintext, key_id = decrypt(
            master_keys=self.keys,
            ciphertext_blob=ciphertext_blob,
            encryption_context=encryption_context,
        )
        arn = self.keys[key_id].arn
        return plaintext, arn

    def re_encrypt(
        self,
        ciphertext_blob,
        source_encryption_context,
        destination_key_id,
        destination_encryption_context,
    ):
        destination_key_id = self.any_id_to_key_id(destination_key_id)

        plaintext, decrypting_arn = self.decrypt(
            ciphertext_blob=ciphertext_blob,
            encryption_context=source_encryption_context,
        )
        new_ciphertext_blob, encrypting_arn = self.encrypt(
            key_id=destination_key_id,
            plaintext=plaintext,
            encryption_context=destination_encryption_context,
        )
        return new_ciphertext_blob, decrypting_arn, encrypting_arn

    def generate_data_key(self, key_id, encryption_context, number_of_bytes,
                          key_spec):
        key_id = self.any_id_to_key_id(key_id)

        if key_spec:
            # Note: Actual validation of key_spec is done in kms.responses
            if key_spec == "AES_128":
                plaintext_len = 16
            else:
                plaintext_len = 32
        else:
            plaintext_len = number_of_bytes

        plaintext = os.urandom(plaintext_len)

        ciphertext_blob, arn = self.encrypt(
            key_id=key_id,
            plaintext=plaintext,
            encryption_context=encryption_context)

        return plaintext, ciphertext_blob, arn

    def list_resource_tags(self, key_id_or_arn):
        key_id = self.get_key_id(key_id_or_arn)
        if key_id in self.keys:
            return self.tagger.list_tags_for_resource(key_id)
        raise JsonRESTError(
            "NotFoundException",
            "The request was rejected because the specified entity or resource could not be found.",
        )

    def tag_resource(self, key_id_or_arn, tags):
        key_id = self.get_key_id(key_id_or_arn)
        if key_id in self.keys:
            self.tagger.tag_resource(key_id, tags)
            return {}
        raise JsonRESTError(
            "NotFoundException",
            "The request was rejected because the specified entity or resource could not be found.",
        )

    def untag_resource(self, key_id_or_arn, tag_names):
        key_id = self.get_key_id(key_id_or_arn)
        if key_id in self.keys:
            self.tagger.untag_resource_using_names(key_id, tag_names)
            return {}
        raise JsonRESTError(
            "NotFoundException",
            "The request was rejected because the specified entity or resource could not be found.",
        )

    def create_grant(
        self,
        key_id,
        grantee_principal,
        operations,
        name,
        constraints,
        retiring_principal,
    ):
        key = self.describe_key(key_id)
        grant = key.add_grant(
            name,
            grantee_principal,
            operations,
            constraints=constraints,
            retiring_principal=retiring_principal,
        )
        return grant.id, grant.token

    def list_grants(self, key_id, grant_id) -> [Grant]:
        key = self.describe_key(key_id)
        return key.list_grants(grant_id)

    def list_retirable_grants(self, retiring_principal):
        grants = []
        for key in self.keys.values():
            grants.extend(key.list_retirable_grants(retiring_principal))
        return grants

    def revoke_grant(self, key_id, grant_id) -> None:
        key = self.describe_key(key_id)
        key.revoke_grant(grant_id)

    def retire_grant(self, key_id, grant_id, grant_token) -> None:
        if grant_token:
            for key in self.keys.values():
                key.retire_grant_by_token(grant_token)
        else:
            key = self.describe_key(key_id)
            key.retire_grant(grant_id)

    def __ensure_valid_sign_and_verify_key(self, key: Key):
        if key.key_usage != "SIGN_VERIFY":
            raise ValidationException((
                "1 validation error detected: Value '{key_id}' at 'KeyId' failed "
                "to satisfy constraint: Member must point to a key with usage: 'SIGN_VERIFY'"
            ).format(key_id=key.id))

    def __ensure_valid_signing_augorithm(self, key: Key, signing_algorithm):
        if signing_algorithm not in key.signing_algorithms:
            raise ValidationException((
                "1 validation error detected: Value '{signing_algorithm}' at 'SigningAlgorithm' failed "
                "to satisfy constraint: Member must satisfy enum value set: "
                "{valid_sign_algorithms}").format(
                    signing_algorithm=signing_algorithm,
                    valid_sign_algorithms=key.signing_algorithms,
                ))

    def sign(self, key_id, message, signing_algorithm):
        """Sign message using generated private key.

        - signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256

        - grant_tokens are not implemented
        """
        key = self.describe_key(key_id)

        self.__ensure_valid_sign_and_verify_key(key)
        self.__ensure_valid_signing_augorithm(key, signing_algorithm)

        # TODO: support more than one hardcoded algorithm based on KeySpec
        signature = key.private_key.sign(
            message,
            padding.PSS(mgf=padding.MGF1(hashes.SHA256()),
                        salt_length=padding.PSS.MAX_LENGTH),
            hashes.SHA256(),
        )

        return key.arn, signature, signing_algorithm

    def verify(self, key_id, message, signature, signing_algorithm):
        """Verify message using public key from generated private key.

        - signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256

        - grant_tokens are not implemented
        """
        key = self.describe_key(key_id)

        self.__ensure_valid_sign_and_verify_key(key)
        self.__ensure_valid_signing_augorithm(key, signing_algorithm)

        if signing_algorithm not in key.signing_algorithms:
            raise ValidationException((
                "1 validation error detected: Value '{signing_algorithm}' at 'SigningAlgorithm' failed "
                "to satisfy constraint: Member must satisfy enum value set: "
                "{valid_sign_algorithms}").format(
                    signing_algorithm=signing_algorithm,
                    valid_sign_algorithms=key.signing_algorithms,
                ))

        public_key = key.private_key.public_key()

        try:
            # TODO: support more than one hardcoded algorithm based on KeySpec
            public_key.verify(
                signature,
                message,
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH,
                ),
                hashes.SHA256(),
            )
            return key.arn, True, signing_algorithm
        except InvalidSignature:
            return key.arn, False, signing_algorithm
Пример #17
0
class DirectoryServiceBackend(BaseBackend):
    """Implementation of DirectoryService APIs."""

    def __init__(self, region_name=None):
        self.region_name = region_name
        self.directories = {}
        self.tagger = TaggingService()

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """List of dicts representing default VPC endpoints for this service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "ds"
        )

    @staticmethod
    def _validate_create_directory_args(
        name, passwd, size, vpc_settings, description, short_name,
    ):  # pylint: disable=too-many-arguments
        """Raise exception if create_directory() args don't meet constraints.

        The error messages are accumulated before the exception is raised.
        """
        error_tuples = []
        passwd_pattern = (
            r"(?=^.{8,64}$)((?=.*\d)(?=.*[A-Z])(?=.*[a-z])|"
            r"(?=.*\d)(?=.*[^A-Za-z0-9\s])(?=.*[a-z])|"
            r"(?=.*[^A-Za-z0-9\s])(?=.*[A-Z])(?=.*[a-z])|"
            r"(?=.*\d)(?=.*[A-Z])(?=.*[^A-Za-z0-9\s]))^.*"
        )
        if not re.match(passwd_pattern, passwd):
            # Can't have an odd number of backslashes in a literal.
            json_pattern = passwd_pattern.replace("\\", r"\\")
            error_tuples.append(
                (
                    "password",
                    passwd,
                    fr"satisfy regular expression pattern: {json_pattern}",
                )
            )

        if size.lower() not in ["small", "large"]:
            error_tuples.append(
                ("size", size, "satisfy enum value set: [Small, Large]")
            )

        name_pattern = r"^([a-zA-Z0-9]+[\\.-])+([a-zA-Z0-9])+$"
        if not re.match(name_pattern, name):
            error_tuples.append(
                ("name", name, fr"satisfy regular expression pattern: {name_pattern}")
            )

        subnet_id_pattern = r"^(subnet-[0-9a-f]{8}|subnet-[0-9a-f]{17})$"
        for subnet in vpc_settings["SubnetIds"]:
            if not re.match(subnet_id_pattern, subnet):
                error_tuples.append(
                    (
                        "vpcSettings.subnetIds",
                        subnet,
                        fr"satisfy regular expression pattern: {subnet_id_pattern}",
                    )
                )

        if description and len(description) > 128:
            error_tuples.append(
                ("description", description, "have length less than or equal to 128")
            )

        short_name_pattern = r'^[^\/:*?"<>|.]+[^\/:*?"<>|]*$'
        if short_name and not re.match(short_name_pattern, short_name):
            json_pattern = short_name_pattern.replace("\\", r"\\").replace('"', r"\"")
            error_tuples.append(
                (
                    "shortName",
                    short_name,
                    fr"satisfy regular expression pattern: {json_pattern}",
                )
            )

        if error_tuples:
            raise DsValidationException(error_tuples)

    @staticmethod
    def _validate_vpc_setting_values(region, vpc_settings):
        """Raise exception if vpc_settings are invalid.

        If settings are valid, add AvailabilityZones to vpc_settings.
        """
        if len(vpc_settings["SubnetIds"]) != 2:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones."
            )

        from moto.ec2 import ec2_backends  # pylint: disable=import-outside-toplevel

        # Subnet IDs are checked before the VPC ID.  The Subnet IDs must
        # be valid and in different availability zones.
        try:
            subnets = ec2_backends[region].get_all_subnets(
                subnet_ids=vpc_settings["SubnetIds"]
            )
        except InvalidSubnetIdError as exc:
            raise InvalidParameterException(
                "Invalid subnet ID(s). They must correspond to two subnets "
                "in different Availability Zones."
            ) from exc

        regions = [subnet.availability_zone for subnet in subnets]
        if regions[0] == regions[1]:
            raise ClientException(
                "Invalid subnet ID(s). The two subnets must be in "
                "different Availability Zones."
            )

        vpcs = ec2_backends[region].describe_vpcs()
        if vpc_settings["VpcId"] not in [x.id for x in vpcs]:
            raise ClientException("Invalid VPC ID.")

        vpc_settings["AvailabilityZones"] = regions

    def create_directory(
        self, region, name, short_name, password, description, size, vpc_settings, tags
    ):  # pylint: disable=too-many-arguments
        """Create a fake Simple Ad Directory."""
        if len(self.directories) > Directory.CLOUDONLY_DIRECTORIES_LIMIT:
            raise DirectoryLimitExceededException(
                f"Directory limit exceeded. A maximum of "
                f"{Directory.CLOUDONLY_DIRECTORIES_LIMIT} directories may be created"
            )

        # botocore doesn't look for missing vpc_settings, but boto3 does.
        if not vpc_settings:
            raise InvalidParameterException("VpcSettings must be specified.")

        self._validate_create_directory_args(
            name, password, size, vpc_settings, description, short_name,
        )
        self._validate_vpc_setting_values(region, vpc_settings)

        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationException(errmsg)

        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise DirectoryLimitExceededException("Tag Limit is exceeding")

        directory = Directory(
            name,
            password,
            size,
            vpc_settings,
            directory_type="SimpleAD",
            short_name=short_name,
            description=description,
        )
        self.directories[directory.directory_id] = directory
        self.tagger.tag_resource(directory.directory_id, tags or [])
        return directory.directory_id

    def _validate_directory_id(self, directory_id):
        """Raise an exception if the directory id is invalid or unknown."""
        # Validation of ID takes precedence over a check for its existence.
        id_pattern = r"^d-[0-9a-f]{10}$"
        if not re.match(id_pattern, directory_id):
            raise DsValidationException(
                [
                    (
                        "directoryId",
                        directory_id,
                        fr"satisfy regular expression pattern: {id_pattern}",
                    )
                ]
            )

        if directory_id not in self.directories:
            raise EntityDoesNotExistException(
                f"Directory {directory_id} does not exist"
            )

    def delete_directory(self, directory_id):
        """Delete directory with the matching ID."""
        self._validate_directory_id(directory_id)
        self.tagger.delete_all_tags_for_resource(directory_id)
        self.directories.pop(directory_id)
        return directory_id

    @paginate(pagination_model=PAGINATION_MODEL)
    def describe_directories(
        self, directory_ids=None, next_token=None, limit=0
    ):  # pylint: disable=unused-argument
        """Return info on all directories or directories with matching IDs."""
        for directory_id in directory_ids or self.directories:
            self._validate_directory_id(directory_id)

        directories = list(self.directories.values())
        if directory_ids:
            directories = [x for x in directories if x.directory_id in directory_ids]
        return sorted(directories, key=lambda x: x.launch_time)

    def get_directory_limits(self):
        """Return hard-coded limits for the directories."""
        counts = {"SimpleAD": 0, "MicrosoftAD": 0, "ConnectedAD": 0}
        for directory in self.directories.values():
            if directory.directory_type == "SimpleAD":
                counts["SimpleAD"] += 1
            elif directory.directory_type in ["MicrosoftAD", "SharedMicrosoftAD"]:
                counts["MicrosoftAD"] += 1
            elif directory.directory_type == "ADConnector":
                counts["ConnectedAD"] += 1

        return {
            "CloudOnlyDirectoriesLimit": Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyDirectoriesCurrentCount": counts["SimpleAD"],
            "CloudOnlyDirectoriesLimitReached": counts["SimpleAD"]
            == Directory.CLOUDONLY_DIRECTORIES_LIMIT,
            "CloudOnlyMicrosoftADLimit": Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "CloudOnlyMicrosoftADCurrentCount": counts["MicrosoftAD"],
            "CloudOnlyMicrosoftADLimitReached": counts["MicrosoftAD"]
            == Directory.CLOUDONLY_MICROSOFT_AD_LIMIT,
            "ConnectedDirectoriesLimit": Directory.CONNECTED_DIRECTORIES_LIMIT,
            "ConnectedDirectoriesCurrentCount": counts["ConnectedAD"],
            "ConnectedDirectoriesLimitReached": counts["ConnectedAD"]
            == Directory.CONNECTED_DIRECTORIES_LIMIT,
        }

    def add_tags_to_resource(self, resource_id, tags):
        """Add or overwrite one or more tags for specified directory."""
        self._validate_directory_id(resource_id)
        errmsg = self.tagger.validate_tags(tags)
        if errmsg:
            raise ValidationException(errmsg)
        if len(tags) > Directory.MAX_TAGS_PER_DIRECTORY:
            raise TagLimitExceededException("Tag limit exceeded")
        self.tagger.tag_resource(resource_id, tags)

    def remove_tags_from_resource(self, resource_id, tag_keys):
        """Removes tags from a directory."""
        self._validate_directory_id(resource_id)
        self.tagger.untag_resource_using_names(resource_id, tag_keys)

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags_for_resource(
        self, resource_id, next_token=None, limit=None,
    ):  # pylint: disable=unused-argument
        """List all tags on a directory."""
        self._validate_directory_id(resource_id)
        return self.tagger.list_tags_for_resource(resource_id).get("Tags")
Пример #18
0
class EventsBackend(BaseBackend):
    ACCOUNT_ID = re.compile(r"^(\d{1,12}|\*)$")
    STATEMENT_ID = re.compile(r"^[a-zA-Z0-9-_]{1,64}$")

    def __init__(self, region_name):
        self.rules = {}
        # This array tracks the order in which the rules have been added, since
        # 2.6 doesn't have OrderedDicts.
        self.rules_order = []
        self.next_tokens = {}
        self.region_name = region_name
        self.event_buses = {}
        self.event_sources = {}
        self.tagger = TaggingService()

        self._add_default_event_bus()

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def _add_default_event_bus(self):
        self.event_buses["default"] = EventBus(self.region_name, "default")

    def _get_rule_by_index(self, i):
        return self.rules.get(self.rules_order[i])

    def _gen_next_token(self, index):
        token = os.urandom(128).encode("base64")
        self.next_tokens[token] = index
        return token

    def _process_token_and_limits(self,
                                  array_len,
                                  next_token=None,
                                  limit=None):
        start_index = 0
        end_index = array_len
        new_next_token = None

        if next_token:
            start_index = self.next_tokens.pop(next_token, 0)

        if limit is not None:
            new_end_index = start_index + int(limit)
            if new_end_index < end_index:
                end_index = new_end_index
                new_next_token = self._gen_next_token(end_index)

        return start_index, end_index, new_next_token

    def delete_rule(self, name):
        self.rules_order.pop(self.rules_order.index(name))
        arn = self.rules.get(name).arn
        if self.tagger.has_tags(arn):
            self.tagger.delete_all_tags_for_resource(arn)
        return self.rules.pop(name) is not None

    def describe_rule(self, name):
        return self.rules.get(name)

    def disable_rule(self, name):
        if name in self.rules:
            self.rules[name].disable()
            return True

        return False

    def enable_rule(self, name):
        if name in self.rules:
            self.rules[name].enable()
            return True

        return False

    def list_rule_names_by_target(self,
                                  target_arn,
                                  next_token=None,
                                  limit=None):
        matching_rules = []
        return_obj = {}

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(self.rules), next_token, limit)

        for i in range(start_index, end_index):
            rule = self._get_rule_by_index(i)
            for target in rule.targets:
                if target["Arn"] == target_arn:
                    matching_rules.append(rule.name)

        return_obj["RuleNames"] = matching_rules
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def list_rules(self, prefix=None, next_token=None, limit=None):
        match_string = ".*"
        if prefix is not None:
            match_string = "^" + prefix + match_string

        match_regex = re.compile(match_string)

        matching_rules = []
        return_obj = {}

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(self.rules), next_token, limit)

        for i in range(start_index, end_index):
            rule = self._get_rule_by_index(i)
            if match_regex.match(rule.name):
                matching_rules.append(rule)

        return_obj["Rules"] = matching_rules
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def list_targets_by_rule(self, rule, next_token=None, limit=None):
        # We'll let a KeyError exception be thrown for response to handle if
        # rule doesn't exist.
        rule = self.rules[rule]

        start_index, end_index, new_next_token = self._process_token_and_limits(
            len(rule.targets), next_token, limit)

        returned_targets = []
        return_obj = {}

        for i in range(start_index, end_index):
            returned_targets.append(rule.targets[i])

        return_obj["Targets"] = returned_targets
        if new_next_token is not None:
            return_obj["NextToken"] = new_next_token

        return return_obj

    def put_rule(self, name, **kwargs):
        new_rule = Rule(name, self.region_name, **kwargs)
        self.rules[new_rule.name] = new_rule
        self.rules_order.append(new_rule.name)
        return new_rule

    def put_targets(self, name, targets):
        rule = self.rules.get(name)

        if rule:
            rule.put_targets(targets)
            return True

        return False

    def put_events(self, events):
        num_events = len(events)

        if num_events < 1:
            raise JsonRESTError("ValidationError", "Need at least 1 event")
        elif num_events > 10:
            # the exact error text is longer, the Value list consists of all the put events
            raise ValidationException(
                "1 validation error detected: "
                "Value '[PutEventsRequestEntry]' at 'entries' failed to satisfy constraint: "
                "Member must have length less than or equal to 10")

        entries = []
        for event in events:
            if "Source" not in event:
                entries.append({
                    "ErrorCode":
                    "InvalidArgument",
                    "ErrorMessage":
                    "Parameter Source is not valid. Reason: Source is a required argument.",
                })
            elif "DetailType" not in event:
                entries.append({
                    "ErrorCode":
                    "InvalidArgument",
                    "ErrorMessage":
                    "Parameter DetailType is not valid. Reason: DetailType is a required argument.",
                })
            elif "Detail" not in event:
                entries.append({
                    "ErrorCode":
                    "InvalidArgument",
                    "ErrorMessage":
                    "Parameter Detail is not valid. Reason: Detail is a required argument.",
                })
            else:
                try:
                    json.loads(event["Detail"])
                except ValueError:  # json.JSONDecodeError exists since Python 3.5
                    entries.append({
                        "ErrorCode": "MalformedDetail",
                        "ErrorMessage": "Detail is malformed.",
                    })
                    continue

                entries.append({"EventId": str(uuid4())})

        # We dont really need to store the events yet
        return entries

    def remove_targets(self, name, ids):
        rule = self.rules.get(name)

        if rule:
            rule.remove_targets(ids)
            return {"FailedEntries": [], "FailedEntryCount": 0}
        else:
            raise JsonRESTError(
                "ResourceNotFoundException",
                "An entity that you specified does not exist",
            )

    def test_event_pattern(self):
        raise NotImplementedError()

    def put_permission(self, event_bus_name, action, principal, statement_id):
        if not event_bus_name:
            event_bus_name = "default"

        event_bus = self.describe_event_bus(event_bus_name)

        if action is None or action != "events:PutEvents":
            raise JsonRESTError(
                "ValidationException",
                "Provided value in parameter 'action' is not supported.",
            )

        if principal is None or self.ACCOUNT_ID.match(principal) is None:
            raise JsonRESTError("InvalidParameterValue",
                                r"Principal must match ^(\d{1,12}|\*)$")

        if statement_id is None or self.STATEMENT_ID.match(
                statement_id) is None:
            raise JsonRESTError(
                "InvalidParameterValue",
                r"StatementId must match ^[a-zA-Z0-9-_]{1,64}$")

        event_bus._permissions[statement_id] = {
            "Action": action,
            "Principal": principal,
        }

    def remove_permission(self, event_bus_name, statement_id):
        if not event_bus_name:
            event_bus_name = "default"

        event_bus = self.describe_event_bus(event_bus_name)

        if not len(event_bus._permissions):
            raise JsonRESTError("ResourceNotFoundException",
                                "EventBus does not have a policy.")

        if not event_bus._permissions.pop(statement_id, None):
            raise JsonRESTError(
                "ResourceNotFoundException",
                "Statement with the provided id does not exist.",
            )

    def describe_event_bus(self, name):
        if not name:
            name = "default"

        event_bus = self.event_buses.get(name)

        if not event_bus:
            raise JsonRESTError("ResourceNotFoundException",
                                "Event bus {} does not exist.".format(name))

        return event_bus

    def create_event_bus(self, name, event_source_name=None):
        if name in self.event_buses:
            raise JsonRESTError(
                "ResourceAlreadyExistsException",
                "Event bus {} already exists.".format(name),
            )

        if not event_source_name and "/" in name:
            raise JsonRESTError("ValidationException",
                                "Event bus name must not contain '/'.")

        if event_source_name and event_source_name not in self.event_sources:
            raise JsonRESTError(
                "ResourceNotFoundException",
                "Event source {} does not exist.".format(event_source_name),
            )

        self.event_buses[name] = EventBus(self.region_name, name)

        return self.event_buses[name]

    def list_event_buses(self, name_prefix):
        if name_prefix:
            return [
                event_bus for event_bus in self.event_buses.values()
                if event_bus.name.startswith(name_prefix)
            ]

        return list(self.event_buses.values())

    def delete_event_bus(self, name):
        if name == "default":
            raise JsonRESTError("ValidationException",
                                "Cannot delete event bus default.")
        self.event_buses.pop(name, None)

    def list_tags_for_resource(self, arn):
        name = arn.split("/")[-1]
        if name in self.rules:
            return self.tagger.list_tags_for_resource(self.rules[name].arn)
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def tag_resource(self, arn, tags):
        name = arn.split("/")[-1]
        if name in self.rules:
            self.tagger.tag_resource(self.rules[name].arn, tags)
            return {}
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))

    def untag_resource(self, arn, tag_names):
        name = arn.split("/")[-1]
        if name in self.rules:
            self.tagger.untag_resource_using_names(self.rules[name].arn,
                                                   tag_names)
            return {}
        raise ResourceNotFoundException(
            "Rule {0} does not exist on EventBus default.".format(name))
Пример #19
0
class CloudTrailBackend(BaseBackend):
    """Implementation of CloudTrail APIs."""
    def __init__(self, region_name):
        self.region_name = region_name
        self.trails = dict()
        self.tagging_service = TaggingService(tag_name="TagsList")

    def create_trail(
        self,
        name,
        bucket_name,
        s3_key_prefix,
        sns_topic_name,
        is_global,
        is_multi_region,
        log_validation,
        is_org_trail,
        cw_log_group_arn,
        cw_role_arn,
        kms_key_id,
        tags_list,
    ):
        trail = Trail(
            self.region_name,
            name,
            bucket_name,
            s3_key_prefix,
            sns_topic_name,
            is_global,
            is_multi_region,
            log_validation,
            is_org_trail,
            cw_log_group_arn,
            cw_role_arn,
            kms_key_id,
        )
        self.trails[name] = trail
        self.tagging_service.tag_resource(trail.arn, tags_list)
        return trail

    def get_trail(self, name_or_arn):
        if len(name_or_arn) < 3:
            raise TrailNameTooShort(actual_length=len(name_or_arn))
        if name_or_arn in self.trails:
            return self.trails[name_or_arn]
        for trail in self.trails.values():
            if trail.arn == name_or_arn:
                return trail
        raise TrailNotFoundException(name_or_arn)

    def get_trail_status(self, name):
        if len(name) < 3:
            raise TrailNameTooShort(actual_length=len(name))
        trail_name = next(
            (trail.trail_name for trail in self.trails.values()
             if trail.trail_name == name or trail.arn == name),
            None,
        )
        if not trail_name:
            # This particular method returns the ARN as part of the error message
            arn = (
                f"arn:aws:cloudtrail:{self.region_name}:{get_account_id()}:trail/{name}"
            )
            raise TrailNotFoundException(name=arn)
        trail = self.trails[trail_name]
        return trail.status

    def describe_trails(self, include_shadow_trails):
        all_trails = []
        if include_shadow_trails:
            for backend in cloudtrail_backends.values():
                all_trails.extend(backend.trails.values())
        else:
            all_trails.extend(self.trails.values())
        return all_trails

    def list_trails(self):
        return self.describe_trails(include_shadow_trails=True)

    def start_logging(self, name):
        trail = self.trails[name]
        trail.start_logging()

    def stop_logging(self, name):
        trail = self.trails[name]
        trail.stop_logging()

    def delete_trail(self, name):
        if name in self.trails:
            del self.trails[name]

    def update_trail(
        self,
        name,
        s3_bucket_name,
        s3_key_prefix,
        sns_topic_name,
        include_global_service_events,
        is_multi_region_trail,
        enable_log_file_validation,
        is_organization_trail,
        cw_log_group_arn,
        cw_role_arn,
        kms_key_id,
    ):
        trail = self.get_trail(name_or_arn=name)
        trail.update(
            s3_bucket_name=s3_bucket_name,
            s3_key_prefix=s3_key_prefix,
            sns_topic_name=sns_topic_name,
            include_global_service_events=include_global_service_events,
            is_multi_region_trail=is_multi_region_trail,
            enable_log_file_validation=enable_log_file_validation,
            is_organization_trail=is_organization_trail,
            cw_log_group_arn=cw_log_group_arn,
            cw_role_arn=cw_role_arn,
            kms_key_id=kms_key_id,
        )
        return trail

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def put_event_selectors(self, trail_name, event_selectors,
                            advanced_event_selectors):
        trail = self.get_trail(trail_name)
        trail.put_event_selectors(event_selectors, advanced_event_selectors)
        trail_arn = trail.arn
        return trail_arn, event_selectors, advanced_event_selectors

    def get_event_selectors(self, trail_name):
        trail = self.get_trail(trail_name)
        event_selectors, advanced_event_selectors = trail.get_event_selectors()
        return trail.arn, event_selectors, advanced_event_selectors

    def add_tags(self, resource_id, tags_list):
        self.tagging_service.tag_resource(resource_id, tags_list)

    def remove_tags(self, resource_id, tags_list):
        self.tagging_service.untag_resource_using_tags(resource_id, tags_list)

    def list_tags(self, resource_id_list):
        """
        Pagination is not yet implemented
        """
        resp = [{"ResourceId": r_id} for r_id in resource_id_list]
        for item in resp:
            item["TagsList"] = self.tagging_service.list_tags_for_resource(
                item["ResourceId"])["TagsList"]
        return resp

    def put_insight_selectors(self, trail_name, insight_selectors):
        trail = self.get_trail(trail_name)
        trail.put_insight_selectors(insight_selectors)
        return trail.arn, insight_selectors

    def get_insight_selectors(self, trail_name):
        trail = self.get_trail(trail_name)
        return trail.arn, trail.get_insight_selectors()
Пример #20
0
class ECRBackend(BaseBackend):
    def __init__(self, region_name):
        self.region_name = region_name
        self.registry_policy = None
        self.replication_config = {"rules": []}
        self.repositories: Dict[str, Repository] = {}
        self.tagger = TaggingService(tag_name="tags")

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint service."""
        docker_endpoint = {
            "AcceptanceRequired":
            False,
            "AvailabilityZones":
            zones,
            "BaseEndpointDnsNames":
            [f"dkr.ecr.{service_region}.vpce.amazonaws.com"],
            "ManagesVpcEndpoints":
            False,
            "Owner":
            "amazon",
            "PrivateDnsName":
            f"*.dkr.ecr.{service_region}.amazonaws.com",
            "PrivateDnsNameVerificationState":
            "verified",
            "PrivateDnsNames": [{
                "PrivateDnsName":
                f"*.dkr.ecr.{service_region}.amazonaws.com"
            }],
            "ServiceId":
            f"vpce-svc-{BaseBackend.vpce_random_number()}",
            "ServiceName":
            f"com.amazonaws.{service_region}.ecr.dkr",
            "ServiceType": [{
                "ServiceType": "Interface"
            }],
            "Tags": [],
            "VpcEndpointPolicySupported":
            True,
        }
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "api.ecr",
            special_service_name="ecr.api") + [docker_endpoint]

    def _get_repository(self, name, registry_id=None) -> Repository:
        repo = self.repositories.get(name)
        reg_id = registry_id or DEFAULT_REGISTRY_ID

        if not repo or repo.registry_id != reg_id:
            raise RepositoryNotFoundException(name, reg_id)
        return repo

    @staticmethod
    def _parse_resource_arn(resource_arn) -> EcrRepositoryArn:
        match = re.match(ECR_REPOSITORY_ARN_PATTERN, resource_arn)
        if not match:
            raise InvalidParameterException(
                "Invalid parameter at 'resourceArn' failed to satisfy constraint: "
                "'Invalid ARN'")
        return EcrRepositoryArn(**match.groupdict())

    def describe_repositories(self, registry_id=None, repository_names=None):
        """
        maxResults and nextToken not implemented
        """
        if repository_names:
            for repository_name in repository_names:
                if repository_name not in self.repositories:
                    raise RepositoryNotFoundException(
                        repository_name, registry_id or DEFAULT_REGISTRY_ID)

        repositories = []
        for repository in self.repositories.values():
            # If a registry_id was supplied, ensure this repository matches
            if registry_id:
                if repository.registry_id != registry_id:
                    continue
            # If a list of repository names was supplied, esure this repository
            # is in that list
            if repository_names:
                if repository.name not in repository_names:
                    continue
            repositories.append(repository.response_object)
        return repositories

    def create_repository(
        self,
        repository_name,
        registry_id,
        encryption_config,
        image_scan_config,
        image_tag_mutablility,
        tags,
    ):
        if self.repositories.get(repository_name):
            raise RepositoryAlreadyExistsException(repository_name,
                                                   DEFAULT_REGISTRY_ID)

        repository = Repository(
            region_name=self.region_name,
            repository_name=repository_name,
            registry_id=registry_id,
            encryption_config=encryption_config,
            image_scan_config=image_scan_config,
            image_tag_mutablility=image_tag_mutablility,
        )
        self.repositories[repository_name] = repository
        self.tagger.tag_resource(repository.arn, tags)

        return repository

    def delete_repository(self,
                          repository_name,
                          registry_id=None,
                          force=False):
        repo = self._get_repository(repository_name, registry_id)

        if repo.images and not force:
            raise RepositoryNotEmptyException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        self.tagger.delete_all_tags_for_resource(repo.arn)
        return self.repositories.pop(repository_name)

    def list_images(self, repository_name, registry_id=None):
        """
        maxResults and filtering not implemented
        """
        repository = None
        found = False
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
            if registry_id:
                if repository.registry_id == registry_id:
                    found = True
            else:
                found = True

        if not found:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        images = []
        for image in repository.images:
            images.append(image)
        return images

    def describe_images(self,
                        repository_name,
                        registry_id=None,
                        image_ids=None):
        repository = self._get_repository(repository_name, registry_id)

        if image_ids:
            response = set(
                repository._get_image(image_id.get("imageTag"),
                                      image_id.get("imageDigest"))
                for image_id in image_ids)

        else:
            response = []
            for image in repository.images:
                response.append(image)

        return response

    def put_image(self, repository_name, image_manifest, image_tag):
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise Exception("{0} is not a repository".format(repository_name))

        existing_images = list(
            filter(
                lambda x: x.response_object["imageManifest"] == image_manifest,
                repository.images,
            ))
        if not existing_images:
            # this image is not in ECR yet
            image = Image(image_tag, image_manifest, repository_name)
            repository.images.append(image)
            return image
        else:
            # update existing image
            existing_images[0].update_tag(image_tag)
            return existing_images[0]

    def batch_get_image(self,
                        repository_name,
                        registry_id=None,
                        image_ids=None):
        """
        The parameter AcceptedMediaTypes has not yet been implemented
        """
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        if not image_ids:
            raise ParamValidationError(
                msg='Missing required parameter in input: "imageIds"')

        response = {"images": [], "failures": []}

        for image_id in image_ids:
            found = False
            for image in repository.images:
                if ("imageDigest" in image_id
                        and image.get_image_digest() == image_id["imageDigest"]
                    ) or ("imageTag" in image_id
                          and image.image_tag == image_id["imageTag"]):
                    found = True
                    response["images"].append(image.response_batch_get_image)

        if not found:
            response["failures"].append({
                "imageId": {
                    "imageTag": image_id.get("imageTag", "null")
                },
                "failureCode":
                "ImageNotFound",
                "failureReason":
                "Requested image not found",
            })

        return response

    def batch_delete_image(self,
                           repository_name,
                           registry_id=None,
                           image_ids=None):
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        if not image_ids:
            raise ParamValidationError(
                msg='Missing required parameter in input: "imageIds"')

        response = {"imageIds": [], "failures": []}

        for image_id in image_ids:
            image_found = False

            # Is request missing both digest and tag?
            if "imageDigest" not in image_id and "imageTag" not in image_id:
                response["failures"].append({
                    "imageId": {},
                    "failureCode":
                    "MissingDigestAndTag",
                    "failureReason":
                    "Invalid request parameters: both tag and digest cannot be null",
                })
                continue

            # If we have a digest, is it valid?
            if "imageDigest" in image_id:
                pattern = re.compile(r"^[0-9a-zA-Z_+\.-]+:[0-9a-fA-F]{64}")
                if not pattern.match(image_id.get("imageDigest")):
                    response["failures"].append({
                        "imageId": {
                            "imageDigest": image_id.get("imageDigest", "null")
                        },
                        "failureCode":
                        "InvalidImageDigest",
                        "failureReason":
                        "Invalid request parameters: image digest should satisfy the regex '[a-zA-Z0-9-_+.]+:[a-fA-F0-9]+'",
                    })
                    continue

            for num, image in enumerate(repository.images):

                # Search by matching both digest and tag
                if "imageDigest" in image_id and "imageTag" in image_id:
                    if (image_id["imageDigest"] == image.get_image_digest()
                            and image_id["imageTag"] in image.image_tags):
                        image_found = True
                        for image_tag in reversed(image.image_tags):
                            repository.images[num].image_tag = image_tag
                            response["imageIds"].append(
                                image.response_batch_delete_image)
                            repository.images[num].remove_tag(image_tag)
                        del repository.images[num]

                # Search by matching digest
                elif ("imageDigest" in image_id
                      and image.get_image_digest() == image_id["imageDigest"]):
                    image_found = True
                    for image_tag in reversed(image.image_tags):
                        repository.images[num].image_tag = image_tag
                        response["imageIds"].append(
                            image.response_batch_delete_image)
                        repository.images[num].remove_tag(image_tag)
                    del repository.images[num]

                # Search by matching tag
                elif ("imageTag" in image_id
                      and image_id["imageTag"] in image.image_tags):
                    image_found = True
                    repository.images[num].image_tag = image_id["imageTag"]
                    response["imageIds"].append(
                        image.response_batch_delete_image)
                    if len(image.image_tags) > 1:
                        repository.images[num].remove_tag(image_id["imageTag"])
                    else:
                        repository.images.remove(image)

            if not image_found:
                failure_response = {
                    "imageId": {},
                    "failureCode": "ImageNotFound",
                    "failureReason": "Requested image not found",
                }

                if "imageDigest" in image_id:
                    failure_response["imageId"]["imageDigest"] = image_id.get(
                        "imageDigest", "null")

                if "imageTag" in image_id:
                    failure_response["imageId"]["imageTag"] = image_id.get(
                        "imageTag", "null")

                response["failures"].append(failure_response)

        return response

    def list_tags_for_resource(self, arn):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)

        return self.tagger.list_tags_for_resource(repo.arn)

    def tag_resource(self, arn, tags):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)
        self.tagger.tag_resource(repo.arn, tags)

        return {}

    def untag_resource(self, arn, tag_keys):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)
        self.tagger.untag_resource_using_names(repo.arn, tag_keys)

        return {}

    def put_image_tag_mutability(self, registry_id, repository_name,
                                 image_tag_mutability):
        if image_tag_mutability not in ["IMMUTABLE", "MUTABLE"]:
            raise InvalidParameterException(
                "Invalid parameter at 'imageTagMutability' failed to satisfy constraint: "
                "'Member must satisfy enum value set: [IMMUTABLE, MUTABLE]'")

        repo = self._get_repository(repository_name, registry_id)
        repo.update(image_tag_mutability=image_tag_mutability)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageTagMutability": repo.image_tag_mutability,
        }

    def put_image_scanning_configuration(self, registry_id, repository_name,
                                         image_scan_config):
        repo = self._get_repository(repository_name, registry_id)
        repo.update(image_scan_config=image_scan_config)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageScanningConfiguration": repo.image_scanning_configuration,
        }

    def set_repository_policy(self, registry_id, repository_name, policy_text):
        repo = self._get_repository(repository_name, registry_id)

        try:
            iam_policy_document_validator = IAMPolicyDocumentValidator(
                policy_text)
            # the repository policy can be defined without a resource field
            iam_policy_document_validator._validate_resource_exist = lambda: None
            # the repository policy can have the old version 2008-10-17
            iam_policy_document_validator._validate_version = lambda: None
            iam_policy_document_validator.validate()
        except MalformedPolicyDocument:
            raise InvalidParameterException(
                "Invalid parameter at 'PolicyText' failed to satisfy constraint: "
                "'Invalid repository policy provided'")

        repo.policy = policy_text

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": repo.policy,
        }

    def get_repository_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)

        if not repo.policy:
            raise RepositoryPolicyNotFoundException(repository_name,
                                                    repo.registry_id)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": repo.policy,
        }

    def delete_repository_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)
        policy = repo.policy

        if not policy:
            raise RepositoryPolicyNotFoundException(repository_name,
                                                    repo.registry_id)

        repo.policy = None

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": policy,
        }

    def put_lifecycle_policy(self, registry_id, repository_name,
                             lifecycle_policy_text):
        repo = self._get_repository(repository_name, registry_id)

        validator = EcrLifecyclePolicyValidator(lifecycle_policy_text)
        validator.validate()

        repo.lifecycle_policy = lifecycle_policy_text

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "lifecyclePolicyText": repo.lifecycle_policy,
        }

    def get_lifecycle_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)

        if not repo.lifecycle_policy:
            raise LifecyclePolicyNotFoundException(repository_name,
                                                   repo.registry_id)

        return {
            "registryId":
            repo.registry_id,
            "repositoryName":
            repository_name,
            "lifecyclePolicyText":
            repo.lifecycle_policy,
            "lastEvaluatedAt":
            iso_8601_datetime_without_milliseconds(datetime.utcnow()),
        }

    def delete_lifecycle_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)
        policy = repo.lifecycle_policy

        if not policy:
            raise LifecyclePolicyNotFoundException(repository_name,
                                                   repo.registry_id)

        repo.lifecycle_policy = None

        return {
            "registryId":
            repo.registry_id,
            "repositoryName":
            repository_name,
            "lifecyclePolicyText":
            policy,
            "lastEvaluatedAt":
            iso_8601_datetime_without_milliseconds(datetime.utcnow()),
        }

    def _validate_registry_policy_action(self, policy_text):
        # only CreateRepository & ReplicateImage actions are allowed
        VALID_ACTIONS = {"ecr:CreateRepository", "ecr:ReplicateImage"}

        policy = json.loads(policy_text)
        for statement in policy["Statement"]:
            action = statement["Action"]
            if isinstance(action, str):
                action = [action]
            if set(action) - VALID_ACTIONS:
                raise MalformedPolicyDocument()

    def put_registry_policy(self, policy_text):
        try:
            iam_policy_document_validator = IAMPolicyDocumentValidator(
                policy_text)
            iam_policy_document_validator.validate()

            self._validate_registry_policy_action(policy_text)
        except MalformedPolicyDocument:
            raise InvalidParameterException(
                "Invalid parameter at 'PolicyText' failed to satisfy constraint: "
                "'Invalid registry policy provided'")

        self.registry_policy = policy_text

        return {
            "registryId": get_account_id(),
            "policyText": policy_text,
        }

    def get_registry_policy(self):
        if not self.registry_policy:
            raise RegistryPolicyNotFoundException(get_account_id())

        return {
            "registryId": get_account_id(),
            "policyText": self.registry_policy,
        }

    def delete_registry_policy(self):
        policy = self.registry_policy
        if not policy:
            raise RegistryPolicyNotFoundException(get_account_id())

        self.registry_policy = None

        return {
            "registryId": get_account_id(),
            "policyText": policy,
        }

    def start_image_scan(self, registry_id, repository_name, image_id):
        repo = self._get_repository(repository_name, registry_id)

        image = repo._get_image(image_id.get("imageTag"),
                                image_id.get("imageDigest"))

        # scanning an image is only allowed once per day
        if image.last_scan and image.last_scan.date() == datetime.today().date(
        ):
            raise LimitExceededException()

        image.last_scan = datetime.today()

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageId": {
                "imageDigest": image.image_digest,
                "imageTag": image.image_tag,
            },
            "imageScanStatus": {
                "status": "IN_PROGRESS"
            },
        }

    def describe_image_scan_findings(self, registry_id, repository_name,
                                     image_id):
        repo = self._get_repository(repository_name, registry_id)

        image = repo._get_image(image_id.get("imageTag"),
                                image_id.get("imageDigest"))

        if not image.last_scan:
            image_id_rep = "{{imageDigest:'{0}', imageTag:'{1}'}}".format(
                image_id.get("imageDigest") or "null",
                image_id.get("imageTag") or "null",
            )
            raise ScanNotFoundException(
                image_id=image_id_rep,
                repository_name=repository_name,
                registry_id=repo.registry_id,
            )

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageId": {
                "imageDigest": image.image_digest,
                "imageTag": image.image_tag,
            },
            "imageScanStatus": {
                "status": "COMPLETE",
                "description": "The scan was completed successfully.",
            },
            "imageScanFindings": {
                "imageScanCompletedAt":
                iso_8601_datetime_without_milliseconds(image.last_scan),
                "vulnerabilitySourceUpdatedAt":
                iso_8601_datetime_without_milliseconds(datetime.utcnow()),
                "findings": [{
                    "name":
                    "CVE-9999-9999",
                    "uri":
                    "https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-9999-9999",
                    "severity":
                    "HIGH",
                    "attributes": [
                        {
                            "key": "package_version",
                            "value": "9.9.9"
                        },
                        {
                            "key": "package_name",
                            "value": "moto_fake"
                        },
                        {
                            "key": "CVSS2_VECTOR",
                            "value": "AV:N/AC:L/Au:N/C:P/I:P/A:P",
                        },
                        {
                            "key": "CVSS2_SCORE",
                            "value": "7.5"
                        },
                    ],
                }],
                "findingSeverityCounts": {
                    "HIGH": 1
                },
            },
        }

    def put_replication_configuration(self, replication_config):
        rules = replication_config["rules"]
        if len(rules) > 1:
            raise ValidationException("This feature is disabled")

        if len(rules) == 1:
            for dest in rules[0]["destinations"]:
                if (dest["region"] == self.region_name
                        and dest["registryId"] == DEFAULT_REGISTRY_ID):
                    raise InvalidParameterException(
                        "Invalid parameter at 'replicationConfiguration' failed to satisfy constraint: "
                        "'Replication destination cannot be the same as the source registry'"
                    )

        self.replication_config = replication_config

        return {"replicationConfiguration": replication_config}

    def describe_registry(self):
        return {
            "registryId": DEFAULT_REGISTRY_ID,
            "replicationConfiguration": self.replication_config,
        }
Пример #21
0
class CloudFrontBackend(BaseBackend):
    def __init__(self, region_name, account_id):
        super().__init__(region_name, account_id)
        self.distributions = dict()
        self.tagger = TaggingService()

        state_manager.register_default_transition(
            "cloudfront::distribution", transition={"progression": "manual", "times": 1}
        )

    def create_distribution(self, distribution_config, tags):
        """
        Not all configuration options are supported yet.  Please raise an issue if
        we're not persisting/returning the correct attributes for your
        use-case.
        """
        dist = Distribution(distribution_config)
        caller_reference = dist.distribution_config.caller_reference
        existing_dist = self._distribution_with_caller_reference(caller_reference)
        if existing_dist:
            raise DistributionAlreadyExists(existing_dist.distribution_id)
        self.distributions[dist.distribution_id] = dist
        self.tagger.tag_resource(dist.arn, tags)
        return dist, dist.location, dist.etag

    def get_distribution(self, distribution_id):
        if distribution_id not in self.distributions:
            raise NoSuchDistribution
        dist = self.distributions[distribution_id]
        dist.advance()
        return dist, dist.etag

    def delete_distribution(self, distribution_id, if_match):
        """
        The IfMatch-value is ignored - any value is considered valid.
        Calling this function without a value is invalid, per AWS' behaviour
        """
        if not if_match:
            raise InvalidIfMatchVersion
        if distribution_id not in self.distributions:
            raise NoSuchDistribution
        del self.distributions[distribution_id]

    def list_distributions(self):
        """
        Pagination is not supported yet.
        """
        for dist in self.distributions.values():
            dist.advance()
        return self.distributions.values()

    def _distribution_with_caller_reference(self, reference):
        for dist in self.distributions.values():
            config = dist.distribution_config
            if config.caller_reference == reference:
                return dist
        return False

    def update_distribution(self, DistributionConfig, Id, IfMatch):
        """
        The IfMatch-value is ignored - any value is considered valid.
        Calling this function without a value is invalid, per AWS' behaviour
        """
        if Id not in self.distributions or Id is None:
            raise NoSuchDistribution
        if not IfMatch:
            raise InvalidIfMatchVersion
        if not DistributionConfig:
            raise NoSuchDistribution
        dist = self.distributions[Id]

        aliases = DistributionConfig["Aliases"]["Items"]["CNAME"]
        dist.distribution_config.config = DistributionConfig
        dist.distribution_config.aliases = aliases
        self.distributions[Id] = dist
        dist.advance()
        return dist, dist.location, dist.etag

    def create_invalidation(self, dist_id, paths, caller_ref):
        dist, _ = self.get_distribution(dist_id)
        invalidation = Invalidation(dist, paths, caller_ref)

        return invalidation

    def list_tags_for_resource(self, resource):
        return self.tagger.list_tags_for_resource(resource)
Пример #22
0
class KmsBackend(BaseBackend):
    def __init__(self):
        self.keys = {}
        self.key_to_aliases = defaultdict(set)
        self.tagger = TaggingService(keyName="TagKey", valueName="TagValue")

    def create_key(self, policy, key_usage, customer_master_key_spec,
                   description, tags, region):
        key = Key(policy, key_usage, customer_master_key_spec, description,
                  region)
        self.keys[key.id] = key
        if tags is not None and len(tags) > 0:
            self.tag_resource(key.id, tags)
        return key

    def update_key_description(self, key_id, description):
        key = self.keys[self.get_key_id(key_id)]
        key.description = description

    def delete_key(self, key_id):
        if key_id in self.keys:
            if key_id in self.key_to_aliases:
                self.key_to_aliases.pop(key_id)
            self.tagger.delete_all_tags_for_resource(key_id)

            return self.keys.pop(key_id)

    def describe_key(self, key_id):
        # allow the different methods (alias, ARN :key/, keyId, ARN alias) to
        # describe key not just KeyId
        key_id = self.get_key_id(key_id)
        if r"alias/" in str(key_id).lower():
            key_id = self.get_key_id_from_alias(key_id.split("alias/")[1])
        return self.keys[self.get_key_id(key_id)]

    def list_keys(self):
        return self.keys.values()

    @staticmethod
    def get_key_id(key_id):
        # Allow use of ARN as well as pure KeyId
        if key_id.startswith("arn:") and ":key/" in key_id:
            return key_id.split(":key/")[1]

        return key_id

    @staticmethod
    def get_alias_name(alias_name):
        # Allow use of ARN as well as alias name
        if alias_name.startswith("arn:") and ":alias/" in alias_name:
            return alias_name.split(":alias/")[1]

        return alias_name

    def any_id_to_key_id(self, key_id):
        """Go from any valid key ID to the raw key ID.

        Acceptable inputs:
        - raw key ID
        - key ARN
        - alias name
        - alias ARN
        """
        key_id = self.get_alias_name(key_id)
        key_id = self.get_key_id(key_id)
        if key_id.startswith("alias/"):
            key_id = self.get_key_id_from_alias(key_id)
        return key_id

    def alias_exists(self, alias_name):
        for aliases in self.key_to_aliases.values():
            if alias_name in aliases:
                return True

        return False

    def add_alias(self, target_key_id, alias_name):
        self.key_to_aliases[target_key_id].add(alias_name)

    def delete_alias(self, alias_name):
        """Delete the alias."""
        for aliases in self.key_to_aliases.values():
            if alias_name in aliases:
                aliases.remove(alias_name)

    def get_all_aliases(self):
        return self.key_to_aliases

    def get_key_id_from_alias(self, alias_name):
        for key_id, aliases in dict(self.key_to_aliases).items():
            if alias_name in ",".join(aliases):
                return key_id
        return None

    def enable_key_rotation(self, key_id):
        self.keys[self.get_key_id(key_id)].key_rotation_status = True

    def disable_key_rotation(self, key_id):
        self.keys[self.get_key_id(key_id)].key_rotation_status = False

    def get_key_rotation_status(self, key_id):
        return self.keys[self.get_key_id(key_id)].key_rotation_status

    def put_key_policy(self, key_id, policy):
        self.keys[self.get_key_id(key_id)].policy = policy

    def get_key_policy(self, key_id):
        return self.keys[self.get_key_id(key_id)].policy

    def disable_key(self, key_id):
        self.keys[key_id].enabled = False
        self.keys[key_id].key_state = "Disabled"

    def enable_key(self, key_id):
        self.keys[key_id].enabled = True
        self.keys[key_id].key_state = "Enabled"

    def cancel_key_deletion(self, key_id):
        self.keys[key_id].key_state = "Disabled"
        self.keys[key_id].deletion_date = None

    def schedule_key_deletion(self, key_id, pending_window_in_days):
        if 7 <= pending_window_in_days <= 30:
            self.keys[key_id].enabled = False
            self.keys[key_id].key_state = "PendingDeletion"
            self.keys[key_id].deletion_date = datetime.now() + timedelta(
                days=pending_window_in_days)
            return unix_time(self.keys[key_id].deletion_date)

    def encrypt(self, key_id, plaintext, encryption_context):
        key_id = self.any_id_to_key_id(key_id)

        ciphertext_blob = encrypt(
            master_keys=self.keys,
            key_id=key_id,
            plaintext=plaintext,
            encryption_context=encryption_context,
        )
        arn = self.keys[key_id].arn
        return ciphertext_blob, arn

    def decrypt(self, ciphertext_blob, encryption_context):
        plaintext, key_id = decrypt(
            master_keys=self.keys,
            ciphertext_blob=ciphertext_blob,
            encryption_context=encryption_context,
        )
        arn = self.keys[key_id].arn
        return plaintext, arn

    def re_encrypt(
        self,
        ciphertext_blob,
        source_encryption_context,
        destination_key_id,
        destination_encryption_context,
    ):
        destination_key_id = self.any_id_to_key_id(destination_key_id)

        plaintext, decrypting_arn = self.decrypt(
            ciphertext_blob=ciphertext_blob,
            encryption_context=source_encryption_context,
        )
        new_ciphertext_blob, encrypting_arn = self.encrypt(
            key_id=destination_key_id,
            plaintext=plaintext,
            encryption_context=destination_encryption_context,
        )
        return new_ciphertext_blob, decrypting_arn, encrypting_arn

    def generate_data_key(self, key_id, encryption_context, number_of_bytes,
                          key_spec, grant_tokens):
        key_id = self.any_id_to_key_id(key_id)

        if key_spec:
            # Note: Actual validation of key_spec is done in kms.responses
            if key_spec == "AES_128":
                plaintext_len = 16
            else:
                plaintext_len = 32
        else:
            plaintext_len = number_of_bytes

        plaintext = os.urandom(plaintext_len)

        ciphertext_blob, arn = self.encrypt(
            key_id=key_id,
            plaintext=plaintext,
            encryption_context=encryption_context)

        return plaintext, ciphertext_blob, arn

    def list_resource_tags(self, key_id):
        if key_id in self.keys:
            return self.tagger.list_tags_for_resource(key_id)
        raise JsonRESTError(
            "NotFoundException",
            "The request was rejected because the specified entity or resource could not be found.",
        )

    def tag_resource(self, key_id, tags):
        if key_id in self.keys:
            self.tagger.tag_resource(key_id, tags)
            return {}
        raise JsonRESTError(
            "NotFoundException",
            "The request was rejected because the specified entity or resource could not be found.",
        )

    def untag_resource(self, key_id, tag_names):
        if key_id in self.keys:
            self.tagger.untag_resource_using_names(key_id, tag_names)
            return {}
        raise JsonRESTError(
            "NotFoundException",
            "The request was rejected because the specified entity or resource could not be found.",
        )
Пример #23
0
class AppSyncBackend(BaseBackend):
    """Implementation of AppSync APIs."""
    def __init__(self, region_name=None):
        self.region_name = region_name
        self.graphql_apis = dict()
        self.tagger = TaggingService()

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def create_graphql_api(
        self,
        name,
        log_config,
        authentication_type,
        user_pool_config,
        open_id_connect_config,
        additional_authentication_providers,
        xray_enabled,
        lambda_authorizer_config,
        tags,
    ):
        graphql_api = GraphqlAPI(
            region=self.region_name,
            name=name,
            authentication_type=authentication_type,
            additional_authentication_providers=
            additional_authentication_providers,
            log_config=log_config,
            xray_enabled=xray_enabled,
            user_pool_config=user_pool_config,
            open_id_connect_config=open_id_connect_config,
            lambda_authorizer_config=lambda_authorizer_config,
        )
        self.graphql_apis[graphql_api.api_id] = graphql_api
        self.tagger.tag_resource(
            graphql_api.arn, TaggingService.convert_dict_to_tags_input(tags))
        return graphql_api

    def update_graphql_api(
        self,
        api_id,
        name,
        log_config,
        authentication_type,
        user_pool_config,
        open_id_connect_config,
        additional_authentication_providers,
        xray_enabled,
        lambda_authorizer_config,
    ):
        graphql_api = self.graphql_apis[api_id]
        graphql_api.update(
            name,
            additional_authentication_providers,
            authentication_type,
            lambda_authorizer_config,
            log_config,
            open_id_connect_config,
            user_pool_config,
            xray_enabled,
        )
        return graphql_api

    def get_graphql_api(self, api_id):
        if api_id not in self.graphql_apis:
            raise GraphqlAPINotFound(api_id)
        return self.graphql_apis[api_id]

    def delete_graphql_api(self, api_id):
        self.graphql_apis.pop(api_id)

    def list_graphql_apis(self):
        """
        Pagination or the maxResults-parameter have not yet been implemented.
        """
        return self.graphql_apis.values()

    def create_api_key(self, api_id, description, expires):
        return self.graphql_apis[api_id].create_api_key(description, expires)

    def delete_api_key(self, api_id, api_key_id):
        self.graphql_apis[api_id].delete_api_key(api_key_id)

    def list_api_keys(self, api_id):
        """
        Pagination or the maxResults-parameter have not yet been implemented.
        """
        if api_id in self.graphql_apis:
            return self.graphql_apis[api_id].list_api_keys()
        else:
            return []

    def update_api_key(self, api_id, api_key_id, description, expires):
        return self.graphql_apis[api_id].update_api_key(
            api_key_id, description, expires)

    def start_schema_creation(self, api_id, definition):
        self.graphql_apis[api_id].start_schema_creation(definition)
        return "PROCESSING"

    def get_schema_creation_status(self, api_id):
        return self.graphql_apis[api_id].get_schema_status()

    def tag_resource(self, resource_arn, tags):
        self.tagger.tag_resource(
            resource_arn, TaggingService.convert_dict_to_tags_input(tags))

    def untag_resource(self, resource_arn, tag_keys):
        self.tagger.untag_resource_using_names(resource_arn, tag_keys)

    def list_tags_for_resource(self, resource_arn):
        return self.tagger.get_tag_dict_for_resource(resource_arn)

    def get_type(self, api_id, type_name, type_format):
        return self.graphql_apis[api_id].get_type(type_name, type_format)
Пример #24
0
class MQBackend(BaseBackend):
    """
    No EC2 integration exists yet - subnet ID's and security group values are not validated. Default values may not exist.
    """

    def __init__(self, region_name=None):
        self.region_name = region_name
        self.brokers = dict()
        self.configs = dict()
        self.tagger = TaggingService()

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def create_broker(
        self,
        authentication_strategy,
        auto_minor_version_upgrade,
        broker_name,
        configuration,
        deployment_mode,
        encryption_options,
        engine_type,
        engine_version,
        host_instance_type,
        ldap_server_metadata,
        logs,
        maintenance_window_start_time,
        publicly_accessible,
        security_groups,
        storage_type,
        subnet_ids,
        tags,
        users,
    ):
        broker = Broker(
            name=broker_name,
            region=self.region_name,
            authentication_strategy=authentication_strategy,
            auto_minor_version_upgrade=auto_minor_version_upgrade,
            configuration=configuration,
            deployment_mode=deployment_mode,
            encryption_options=encryption_options,
            engine_type=engine_type,
            engine_version=engine_version,
            host_instance_type=host_instance_type,
            ldap_server_metadata=ldap_server_metadata,
            logs=logs,
            maintenance_window_start_time=maintenance_window_start_time,
            publicly_accessible=publicly_accessible,
            security_groups=security_groups,
            storage_type=storage_type,
            subnet_ids=subnet_ids,
            users=users,
        )
        self.brokers[broker.id] = broker
        self.create_tags(broker.arn, tags)
        return broker.arn, broker.id

    def delete_broker(self, broker_id):
        del self.brokers[broker_id]

    def describe_broker(self, broker_id):
        if broker_id not in self.brokers:
            raise UnknownBroker(broker_id)
        return self.brokers[broker_id]

    def reboot_broker(self, broker_id):
        self.brokers[broker_id].reboot()

    def list_brokers(self):
        """
        Pagination is not yet implemented
        """
        return self.brokers.values()

    def create_user(self, broker_id, username, console_access, groups):
        broker = self.describe_broker(broker_id)
        broker.create_user(username, console_access, groups)

    def update_user(self, broker_id, console_access, groups, username):
        broker = self.describe_broker(broker_id)
        broker.update_user(username, console_access, groups)

    def describe_user(self, broker_id, username):
        broker = self.describe_broker(broker_id)
        return broker.get_user(username)

    def delete_user(self, broker_id, username):
        broker = self.describe_broker(broker_id)
        broker.delete_user(username)

    def list_users(self, broker_id):
        broker = self.describe_broker(broker_id)
        return broker.list_users()

    def create_configuration(self, name, engine_type, engine_version, tags):
        if engine_type.upper() == "RABBITMQ":
            raise UnsupportedEngineType(engine_type)
        if engine_type.upper() != "ACTIVEMQ":
            raise UnknownEngineType(engine_type)
        config = Configuration(
            region=self.region_name,
            name=name,
            engine_type=engine_type,
            engine_version=engine_version,
        )
        self.configs[config.id] = config
        self.tagger.tag_resource(
            config.arn, self.tagger.convert_dict_to_tags_input(tags)
        )
        return config

    def update_configuration(self, config_id, data, description):
        """
        No validation occurs on the provided XML. The authenticationStrategy may be changed depending on the provided configuration.
        """
        config = self.configs[config_id]
        config.update(data, description)
        return config

    def describe_configuration(self, config_id):
        if config_id not in self.configs:
            raise UnknownConfiguration(config_id)
        return self.configs[config_id]

    def describe_configuration_revision(self, config_id, revision_id):
        config = self.configs[config_id]
        return config.get_revision(revision_id)

    def list_configurations(self):
        """
        Pagination has not yet been implemented.
        """
        return self.configs.values()

    def create_tags(self, resource_arn, tags):
        self.tagger.tag_resource(
            resource_arn, self.tagger.convert_dict_to_tags_input(tags)
        )

    def list_tags(self, arn):
        return self.tagger.get_tag_dict_for_resource(arn)

    def delete_tags(self, resource_arn, tag_keys):
        if not isinstance(tag_keys, list):
            tag_keys = [tag_keys]
        self.tagger.untag_resource_using_names(resource_arn, tag_keys)

    def update_broker(
        self,
        authentication_strategy,
        auto_minor_version_upgrade,
        broker_id,
        configuration,
        engine_version,
        host_instance_type,
        ldap_server_metadata,
        logs,
        maintenance_window_start_time,
        security_groups,
    ):
        broker = self.describe_broker(broker_id)
        broker.update(
            authentication_strategy=authentication_strategy,
            auto_minor_version_upgrade=auto_minor_version_upgrade,
            configuration=configuration,
            engine_version=engine_version,
            host_instance_type=host_instance_type,
            ldap_server_metadata=ldap_server_metadata,
            logs=logs,
            maintenance_window_start_time=maintenance_window_start_time,
            security_groups=security_groups,
        )
Пример #25
0
class ECRBackend(BaseBackend):
    def __init__(self, region_name):
        self.region_name = region_name
        self.repositories: Dict[str, Repository] = {}
        self.tagger = TaggingService(tagName="tags")

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def _get_repository(self, name, registry_id=None) -> Repository:
        repo = self.repositories.get(name)
        reg_id = registry_id or DEFAULT_REGISTRY_ID

        if not repo or repo.registry_id != reg_id:
            raise RepositoryNotFoundException(name, reg_id)
        return repo

    @staticmethod
    def _parse_resource_arn(resource_arn) -> EcrRepositoryArn:
        match = re.match(ECR_REPOSITORY_ARN_PATTERN, resource_arn)
        if not match:
            raise InvalidParameterException(
                "Invalid parameter at 'resourceArn' failed to satisfy constraint: "
                "'Invalid ARN'")
        return EcrRepositoryArn(**match.groupdict())

    def describe_repositories(self, registry_id=None, repository_names=None):
        """
        maxResults and nextToken not implemented
        """
        if repository_names:
            for repository_name in repository_names:
                if repository_name not in self.repositories:
                    raise RepositoryNotFoundException(
                        repository_name, registry_id or DEFAULT_REGISTRY_ID)

        repositories = []
        for repository in self.repositories.values():
            # If a registry_id was supplied, ensure this repository matches
            if registry_id:
                if repository.registry_id != registry_id:
                    continue
            # If a list of repository names was supplied, esure this repository
            # is in that list
            if repository_names:
                if repository.name not in repository_names:
                    continue
            repositories.append(repository.response_object)
        return repositories

    def create_repository(
        self,
        repository_name,
        encryption_config,
        image_scan_config,
        image_tag_mutablility,
        tags,
    ):
        if self.repositories.get(repository_name):
            raise RepositoryAlreadyExistsException(repository_name,
                                                   DEFAULT_REGISTRY_ID)

        repository = Repository(
            region_name=self.region_name,
            repository_name=repository_name,
            encryption_config=encryption_config,
            image_scan_config=image_scan_config,
            image_tag_mutablility=image_tag_mutablility,
        )
        self.repositories[repository_name] = repository
        self.tagger.tag_resource(repository.arn, tags)

        return repository

    def delete_repository(self, repository_name, registry_id=None):
        repo = self._get_repository(repository_name, registry_id)

        if repo.images:
            raise RepositoryNotEmptyException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        self.tagger.delete_all_tags_for_resource(repo.arn)
        return self.repositories.pop(repository_name)

    def list_images(self, repository_name, registry_id=None):
        """
        maxResults and filtering not implemented
        """
        repository = None
        found = False
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
            if registry_id:
                if repository.registry_id == registry_id:
                    found = True
            else:
                found = True

        if not found:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        images = []
        for image in repository.images:
            images.append(image)
        return images

    def describe_images(self,
                        repository_name,
                        registry_id=None,
                        image_ids=None):

        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        if image_ids:
            response = set()
            for image_id in image_ids:
                found = False
                for image in repository.images:
                    if ("imageDigest" in image_id and image.get_image_digest()
                            == image_id["imageDigest"]) or (
                                "imageTag" in image_id
                                and image_id["imageTag"] in image.image_tags):
                        found = True
                        response.add(image)
                if not found:
                    image_id_representation = "{imageDigest:'%s', imageTag:'%s'}" % (
                        image_id.get("imageDigest", "null"),
                        image_id.get("imageTag", "null"),
                    )
                    raise ImageNotFoundException(
                        image_id=image_id_representation,
                        repository_name=repository_name,
                        registry_id=registry_id or DEFAULT_REGISTRY_ID,
                    )

        else:
            response = []
            for image in repository.images:
                response.append(image)

        return response

    def put_image(self, repository_name, image_manifest, image_tag):
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise Exception("{0} is not a repository".format(repository_name))

        existing_images = list(
            filter(
                lambda x: x.response_object["imageManifest"] == image_manifest,
                repository.images,
            ))
        if not existing_images:
            # this image is not in ECR yet
            image = Image(image_tag, image_manifest, repository_name)
            repository.images.append(image)
            return image
        else:
            # update existing image
            existing_images[0].update_tag(image_tag)
            return existing_images[0]

    def batch_get_image(
        self,
        repository_name,
        registry_id=None,
        image_ids=None,
        accepted_media_types=None,
    ):
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        if not image_ids:
            raise ParamValidationError(
                msg='Missing required parameter in input: "imageIds"')

        response = {"images": [], "failures": []}

        for image_id in image_ids:
            found = False
            for image in repository.images:
                if ("imageDigest" in image_id
                        and image.get_image_digest() == image_id["imageDigest"]
                    ) or ("imageTag" in image_id
                          and image.image_tag == image_id["imageTag"]):
                    found = True
                    response["images"].append(image.response_batch_get_image)

        if not found:
            response["failures"].append({
                "imageId": {
                    "imageTag": image_id.get("imageTag", "null")
                },
                "failureCode":
                "ImageNotFound",
                "failureReason":
                "Requested image not found",
            })

        return response

    def batch_delete_image(self,
                           repository_name,
                           registry_id=None,
                           image_ids=None):
        if repository_name in self.repositories:
            repository = self.repositories[repository_name]
        else:
            raise RepositoryNotFoundException(
                repository_name, registry_id or DEFAULT_REGISTRY_ID)

        if not image_ids:
            raise ParamValidationError(
                msg='Missing required parameter in input: "imageIds"')

        response = {"imageIds": [], "failures": []}

        for image_id in image_ids:
            image_found = False

            # Is request missing both digest and tag?
            if "imageDigest" not in image_id and "imageTag" not in image_id:
                response["failures"].append({
                    "imageId": {},
                    "failureCode":
                    "MissingDigestAndTag",
                    "failureReason":
                    "Invalid request parameters: both tag and digest cannot be null",
                })
                continue

            # If we have a digest, is it valid?
            if "imageDigest" in image_id:
                pattern = re.compile(r"^[0-9a-zA-Z_+\.-]+:[0-9a-fA-F]{64}")
                if not pattern.match(image_id.get("imageDigest")):
                    response["failures"].append({
                        "imageId": {
                            "imageDigest": image_id.get("imageDigest", "null")
                        },
                        "failureCode":
                        "InvalidImageDigest",
                        "failureReason":
                        "Invalid request parameters: image digest should satisfy the regex '[a-zA-Z0-9-_+.]+:[a-fA-F0-9]+'",
                    })
                    continue

            for num, image in enumerate(repository.images):

                # Search by matching both digest and tag
                if "imageDigest" in image_id and "imageTag" in image_id:
                    if (image_id["imageDigest"] == image.get_image_digest()
                            and image_id["imageTag"] in image.image_tags):
                        image_found = True
                        for image_tag in reversed(image.image_tags):
                            repository.images[num].image_tag = image_tag
                            response["imageIds"].append(
                                image.response_batch_delete_image)
                            repository.images[num].remove_tag(image_tag)
                        del repository.images[num]

                # Search by matching digest
                elif ("imageDigest" in image_id
                      and image.get_image_digest() == image_id["imageDigest"]):
                    image_found = True
                    for image_tag in reversed(image.image_tags):
                        repository.images[num].image_tag = image_tag
                        response["imageIds"].append(
                            image.response_batch_delete_image)
                        repository.images[num].remove_tag(image_tag)
                    del repository.images[num]

                # Search by matching tag
                elif ("imageTag" in image_id
                      and image_id["imageTag"] in image.image_tags):
                    image_found = True
                    repository.images[num].image_tag = image_id["imageTag"]
                    response["imageIds"].append(
                        image.response_batch_delete_image)
                    if len(image.image_tags) > 1:
                        repository.images[num].remove_tag(image_id["imageTag"])
                    else:
                        repository.images.remove(image)

                if not image_found:
                    failure_response = {
                        "imageId": {},
                        "failureCode": "ImageNotFound",
                        "failureReason": "Requested image not found",
                    }

                    if "imageDigest" in image_id:
                        failure_response["imageId"][
                            "imageDigest"] = image_id.get(
                                "imageDigest", "null")

                    if "imageTag" in image_id:
                        failure_response["imageId"]["imageTag"] = image_id.get(
                            "imageTag", "null")

                    response["failures"].append(failure_response)

        return response

    def list_tags_for_resource(self, arn):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)

        return self.tagger.list_tags_for_resource(repo.arn)

    def tag_resource(self, arn, tags):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)
        self.tagger.tag_resource(repo.arn, tags)

        return {}

    def untag_resource(self, arn, tag_keys):
        resource = self._parse_resource_arn(arn)
        repo = self._get_repository(resource.repo_name, resource.account_id)
        self.tagger.untag_resource_using_names(repo.arn, tag_keys)

        return {}

    def put_image_tag_mutability(self, registry_id, repository_name,
                                 image_tag_mutability):
        if image_tag_mutability not in ["IMMUTABLE", "MUTABLE"]:
            raise InvalidParameterException(
                "Invalid parameter at 'imageTagMutability' failed to satisfy constraint: "
                "'Member must satisfy enum value set: [IMMUTABLE, MUTABLE]'")

        repo = self._get_repository(repository_name, registry_id)
        repo.update(image_tag_mutability=image_tag_mutability)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageTagMutability": repo.image_tag_mutability,
        }

    def put_image_scanning_configuration(self, registry_id, repository_name,
                                         image_scan_config):
        repo = self._get_repository(repository_name, registry_id)
        repo.update(image_scan_config=image_scan_config)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "imageScanningConfiguration": repo.image_scanning_configuration,
        }

    def set_repository_policy(self, registry_id, repository_name, policy_text):
        repo = self._get_repository(repository_name, registry_id)

        try:
            iam_policy_document_validator = IAMPolicyDocumentValidator(
                policy_text)
            # the repository policy can be defined without a resource field
            iam_policy_document_validator._validate_resource_exist = lambda: None
            # the repository policy can have the old version 2008-10-17
            iam_policy_document_validator._validate_version = lambda: None
            iam_policy_document_validator.validate()
        except MalformedPolicyDocument:
            raise InvalidParameterException(
                "Invalid parameter at 'PolicyText' failed to satisfy constraint: "
                "'Invalid repository policy provided'")

        repo.policy = policy_text

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": repo.policy,
        }

    def get_repository_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)

        if not repo.policy:
            raise RepositoryPolicyNotFoundException(repository_name,
                                                    repo.registry_id)

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": repo.policy,
        }

    def delete_repository_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)
        policy = repo.policy

        if not policy:
            raise RepositoryPolicyNotFoundException(repository_name,
                                                    repo.registry_id)

        repo.policy = None

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "policyText": policy,
        }

    def put_lifecycle_policy(self, registry_id, repository_name,
                             lifecycle_policy_text):
        repo = self._get_repository(repository_name, registry_id)

        validator = EcrLifecyclePolicyValidator(lifecycle_policy_text)
        validator.validate()

        repo.lifecycle_policy = lifecycle_policy_text

        return {
            "registryId": repo.registry_id,
            "repositoryName": repository_name,
            "lifecyclePolicyText": repo.lifecycle_policy,
        }

    def get_lifecycle_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)

        if not repo.lifecycle_policy:
            raise LifecyclePolicyNotFoundException(repository_name,
                                                   repo.registry_id)

        return {
            "registryId":
            repo.registry_id,
            "repositoryName":
            repository_name,
            "lifecyclePolicyText":
            repo.lifecycle_policy,
            "lastEvaluatedAt":
            iso_8601_datetime_without_milliseconds(datetime.utcnow()),
        }

    def delete_lifecycle_policy(self, registry_id, repository_name):
        repo = self._get_repository(repository_name, registry_id)
        policy = repo.lifecycle_policy

        if not policy:
            raise LifecyclePolicyNotFoundException(repository_name,
                                                   repo.registry_id)

        repo.lifecycle_policy = None

        return {
            "registryId":
            repo.registry_id,
            "repositoryName":
            repository_name,
            "lifecyclePolicyText":
            policy,
            "lastEvaluatedAt":
            iso_8601_datetime_without_milliseconds(datetime.utcnow()),
        }