# AutoScalingBackend.__init__ (cf. the full class at the end of this section)
def __init__(self, ec2_backend, elb_backend, elbv2_backend):
    self.autoscaling_groups = OrderedDict()
    self.launch_configurations = OrderedDict()
    self.policies = {}
    self.ec2_backend = ec2_backend
    self.elb_backend = elb_backend
    self.elbv2_backend = elbv2_backend

# SNSBackend.__init__ (minimal variant)
def __init__(self, region_name):
    super(SNSBackend, self).__init__()
    self.topics = OrderedDict()
    self.subscriptions = OrderedDict()
    self.applications = {}
    self.platform_endpoints = {}
    self.region_name = region_name

class Shard(BaseModel):

    def __init__(self, shard_id, starting_hash, ending_hash):
        self._shard_id = shard_id
        self.starting_hash = starting_hash
        self.ending_hash = ending_hash
        self.records = OrderedDict()

    @property
    def shard_id(self):
        return "shardId-{0}".format(str(self._shard_id).zfill(12))

    def get_records(self, last_sequence_id, limit):
        last_sequence_id = int(last_sequence_id)
        results = []

        for sequence_number, record in self.records.items():
            if sequence_number > last_sequence_id:
                results.append(record)
                last_sequence_id = sequence_number

                if len(results) == limit:
                    break

        return results, last_sequence_id

    def put_record(self, partition_key, data, explicit_hash_key):
        # Note: this function is not safe for concurrency
        if self.records:
            last_sequence_number = self.get_max_sequence_number()
        else:
            last_sequence_number = 0
        sequence_number = last_sequence_number + 1
        self.records[sequence_number] = Record(
            partition_key, data, sequence_number, explicit_hash_key)
        return sequence_number

    def get_min_sequence_number(self):
        if self.records:
            return list(self.records.keys())[0]
        return 0

    def get_max_sequence_number(self):
        if self.records:
            return list(self.records.keys())[-1]
        return 0

    def to_json(self):
        return {
            "HashKeyRange": {
                "EndingHashKey": str(self.ending_hash),
                "StartingHashKey": str(self.starting_hash)
            },
            "SequenceNumberRange": {
                "EndingSequenceNumber": self.get_max_sequence_number(),
                "StartingSequenceNumber": self.get_min_sequence_number(),
            },
            "ShardId": self.shard_id
        }

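# Minimal usage sketch of the Shard model above. `Record` and `BaseModel`
# come from the surrounding module and are assumed importable; the hash
# bounds are arbitrary illustration values.
shard = Shard(0, starting_hash=0, ending_hash=2 ** 128)
seq = shard.put_record("partition-key", b"payload", explicit_hash_key="")
records, last_seq = shard.get_records(last_sequence_id=0, limit=10)
assert shard.shard_id == "shardId-000000000000"
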
class FakeTargetGroup(BaseModel):

    def __init__(self, name, arn, vpc_id, protocol, port,
                 healthcheck_protocol, healthcheck_port, healthcheck_path,
                 healthcheck_interval_seconds, healthcheck_timeout_seconds,
                 healthy_threshold_count, unhealthy_threshold_count):
        self.name = name
        self.arn = arn
        self.vpc_id = vpc_id
        self.protocol = protocol
        self.port = port
        self.healthcheck_protocol = healthcheck_protocol
        self.healthcheck_port = healthcheck_port
        self.healthcheck_path = healthcheck_path
        self.healthcheck_interval_seconds = healthcheck_interval_seconds
        self.healthcheck_timeout_seconds = healthcheck_timeout_seconds
        self.healthy_threshold_count = healthy_threshold_count
        self.unhealthy_threshold_count = unhealthy_threshold_count
        self.load_balancer_arns = []
        self.tags = {}
        self.attributes = {
            'deregistration_delay.timeout_seconds': 300,
            'stickiness.enabled': 'false',
        }
        self.targets = OrderedDict()

    def register(self, targets):
        for target in targets:
            self.targets[target['id']] = {
                'id': target['id'],
                'port': target.get('port', self.port),
            }

    def deregister(self, targets):
        for target in targets:
            t = self.targets.pop(target['id'], None)
            if not t:
                raise InvalidTargetError()

    def add_tag(self, key, value):
        if len(self.tags) >= 10 and key not in self.tags:
            raise TooManyTagsError()
        self.tags[key] = value

    def health_for(self, target):
        t = self.targets.get(target['id'])
        if t is None:
            raise InvalidTargetError()
        return FakeHealthStatus(t['id'], t['port'], self.healthcheck_port,
                                'healthy')

# SNSBackend.__init__, variant with SMS attributes and permissions
# (cf. the full class later in this section)
def __init__(self, region_name):
    super(SNSBackend, self).__init__()
    self.topics = OrderedDict()
    self.subscriptions = OrderedDict()
    self.applications = {}
    self.platform_endpoints = {}
    self.region_name = region_name
    self.sms_attributes = {}
    self.opt_out_numbers = ['+447420500600', '+447420505401', '+447632960543',
                            '+447632960028', '+447700900149', '+447700900550',
                            '+447700900545', '+447700900907']
    self.permissions = {}

class SNSBackend(BaseBackend):

    def __init__(self):
        self.topics = OrderedDict()
        self.subscriptions = OrderedDict()

    def create_topic(self, name):
        topic = Topic(name, self)
        self.topics[topic.arn] = topic
        return topic

    def _get_values_nexttoken(self, values_map, next_token=None):
        if next_token is None:
            next_token = 0
        next_token = int(next_token)
        values = list(values_map.values())[
            next_token: next_token + DEFAULT_PAGE_SIZE]
        if len(values) == DEFAULT_PAGE_SIZE:
            next_token = next_token + DEFAULT_PAGE_SIZE
        else:
            next_token = None
        return values, next_token

    def list_topics(self, next_token=None):
        return self._get_values_nexttoken(self.topics, next_token)

    def delete_topic(self, arn):
        self.topics.pop(arn)

    def get_topic(self, arn):
        return self.topics[arn]

    def set_topic_attribute(self, topic_arn, attribute_name, attribute_value):
        topic = self.get_topic(topic_arn)
        setattr(topic, attribute_name, attribute_value)

    def subscribe(self, topic_arn, endpoint, protocol):
        topic = self.get_topic(topic_arn)
        subscription = Subscription(topic, endpoint, protocol)
        self.subscriptions[subscription.arn] = subscription
        return subscription

    def unsubscribe(self, subscription_arn):
        self.subscriptions.pop(subscription_arn)

    def list_subscriptions(self, topic_arn=None, next_token=None):
        if topic_arn:
            topic = self.get_topic(topic_arn)
            filtered = OrderedDict([(k, sub) for k, sub
                                    in self.subscriptions.items()
                                    if sub.topic == topic])
            return self._get_values_nexttoken(filtered, next_token)
        else:
            return self._get_values_nexttoken(self.subscriptions, next_token)

    def publish(self, topic_arn, message):
        topic = self.get_topic(topic_arn)
        message_id = topic.publish(message)
        return message_id

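# Quick usage sketch of the minimal SNSBackend above (Topic and
# Subscription are defined elsewhere in the surrounding module):
backend = SNSBackend()
topic = backend.create_topic("my-topic")
sub = backend.subscribe(topic.arn, "http://example.com/hook", "http")
topics, next_token = backend.list_topics()
backend.publish(topic.arn, "hello")
backend.unsubscribe(sub.arn)
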
class Shard(object):

    def __init__(self, shard_id):
        self.shard_id = shard_id
        self.records = OrderedDict()

    def get_records(self, last_sequence_id, limit):
        last_sequence_id = int(last_sequence_id)
        results = []

        for sequence_number, record in self.records.items():
            if sequence_number > last_sequence_id:
                results.append(record)
                last_sequence_id = sequence_number

                if len(results) == limit:
                    break

        return results, last_sequence_id

    def put_record(self, partition_key, data):
        # Note: this function is not safe for concurrency
        if self.records:
            last_sequence_number = self.get_max_sequence_number()
        else:
            last_sequence_number = 0
        sequence_number = last_sequence_number + 1
        self.records[sequence_number] = Record(partition_key, data,
                                               sequence_number)
        return sequence_number

    def get_min_sequence_number(self):
        if self.records:
            return list(self.records.keys())[0]
        return 0

    def get_max_sequence_number(self):
        if self.records:
            return list(self.records.keys())[-1]
        return 0

    def to_json(self):
        return {
            "HashKeyRange": {
                "EndingHashKey": "113427455640312821154458202477256070484",
                "StartingHashKey": "0"
            },
            "SequenceNumberRange": {
                "EndingSequenceNumber": self.get_max_sequence_number(),
                "StartingSequenceNumber": self.get_min_sequence_number(),
            },
            "ShardId": self.shard_id
        }

# RedshiftBackend.__init__ (variant with snapshot copy grants)
def __init__(self, ec2_backend, region_name):
    self.region = region_name
    self.clusters = {}
    self.subnet_groups = {}
    self.security_groups = {
        "Default": SecurityGroup("Default",
                                 "Default Redshift Security Group",
                                 self.region)
    }
    self.parameter_groups = {
        "default.redshift-1.0": ParameterGroup(
            "default.redshift-1.0",
            "redshift-1.0",
            "Default Redshift parameter group",
            self.region
        )
    }
    self.ec2_backend = ec2_backend
    self.snapshots = OrderedDict()
    self.RESOURCE_TYPE_MAP = {
        'cluster': self.clusters,
        'parametergroup': self.parameter_groups,
        'securitygroup': self.security_groups,
        'snapshot': self.snapshots,
        'subnetgroup': self.subnet_groups
    }
    self.snapshot_copy_grants = {}

# FakeTargetGroup.__init__ with all health-check arguments required
def __init__(self, name, arn, vpc_id, protocol, port,
             healthcheck_protocol, healthcheck_port, healthcheck_path,
             healthcheck_interval_seconds, healthcheck_timeout_seconds,
             healthy_threshold_count, unhealthy_threshold_count):
    self.name = name
    self.arn = arn
    self.vpc_id = vpc_id
    self.protocol = protocol
    self.port = port
    self.healthcheck_protocol = healthcheck_protocol
    self.healthcheck_port = healthcheck_port
    self.healthcheck_path = healthcheck_path
    self.healthcheck_interval_seconds = healthcheck_interval_seconds
    self.healthcheck_timeout_seconds = healthcheck_timeout_seconds
    self.healthy_threshold_count = healthy_threshold_count
    self.unhealthy_threshold_count = unhealthy_threshold_count
    self.load_balancer_arns = []
    self.tags = {}
    self.attributes = {
        'deregistration_delay.timeout_seconds': 300,
        'stickiness.enabled': 'false',
    }
    self.targets = OrderedDict()

# Cognito user pool model: __init__ loads the bundled private JWKS
def __init__(self, region, name, extended_config):
    self.region = region
    self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex))
    self.name = name
    self.status = None
    self.extended_config = extended_config or {}
    self.creation_date = datetime.datetime.utcnow()
    self.last_modified_date = datetime.datetime.utcnow()

    self.clients = OrderedDict()
    self.identity_providers = OrderedDict()
    self.users = OrderedDict()
    self.refresh_tokens = {}
    self.access_tokens = {}
    self.id_tokens = {}

    with open(os.path.join(os.path.dirname(__file__),
                           "resources/jwks-private.json")) as f:
        self.json_web_key = json.loads(f.read())

class DataPipelineBackend(BaseBackend):

    def __init__(self):
        self.pipelines = OrderedDict()

    def create_pipeline(self, name, unique_id, **kwargs):
        pipeline = Pipeline(name, unique_id, **kwargs)
        self.pipelines[pipeline.pipeline_id] = pipeline
        return pipeline

    def list_pipelines(self):
        return self.pipelines.values()

    def describe_pipelines(self, pipeline_ids):
        pipelines = [pipeline for pipeline in self.pipelines.values()
                     if pipeline.pipeline_id in pipeline_ids]
        return pipelines

    def get_pipeline(self, pipeline_id):
        return self.pipelines[pipeline_id]

    def delete_pipeline(self, pipeline_id):
        self.pipelines.pop(pipeline_id, None)

    def put_pipeline_definition(self, pipeline_id, pipeline_objects):
        pipeline = self.get_pipeline(pipeline_id)
        pipeline.set_pipeline_objects(pipeline_objects)

    def get_pipeline_definition(self, pipeline_id):
        pipeline = self.get_pipeline(pipeline_id)
        return pipeline.objects

    def describe_objects(self, object_ids, pipeline_id):
        pipeline = self.get_pipeline(pipeline_id)
        pipeline_objects = [
            pipeline_object for pipeline_object in pipeline.objects
            if pipeline_object.object_id in object_ids
        ]
        return pipeline_objects

    def activate_pipeline(self, pipeline_id):
        pipeline = self.get_pipeline(pipeline_id)
        pipeline.activate()

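# Usage sketch for DataPipelineBackend; Pipeline comes from the
# surrounding module, and extra keyword arguments are passed through to it.
backend = DataPipelineBackend()
pipeline = backend.create_pipeline("my-pipeline", "unique-token")
backend.activate_pipeline(pipeline.pipeline_id)
assert backend.describe_pipelines([pipeline.pipeline_id]) == [pipeline]
backend.delete_pipeline(pipeline.pipeline_id)
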
# FakeTargetGroup.__init__ variant with defaults; here the health-check
# port falls back to the target port itself
def __init__(self, name, arn, vpc_id, protocol, port,
             healthcheck_protocol=None, healthcheck_port=None,
             healthcheck_path=None, healthcheck_interval_seconds=None,
             healthcheck_timeout_seconds=None, healthy_threshold_count=None,
             unhealthy_threshold_count=None, matcher=None, target_type=None):
    # TODO: default values differ when you add a Network Load Balancer
    self.name = name
    self.arn = arn
    self.vpc_id = vpc_id
    self.protocol = protocol
    self.port = port
    self.healthcheck_protocol = healthcheck_protocol or 'HTTP'
    self.healthcheck_port = healthcheck_port or str(self.port)
    self.healthcheck_path = healthcheck_path or '/'
    self.healthcheck_interval_seconds = healthcheck_interval_seconds or 30
    self.healthcheck_timeout_seconds = healthcheck_timeout_seconds or 5
    self.healthy_threshold_count = healthy_threshold_count or 5
    self.unhealthy_threshold_count = unhealthy_threshold_count or 2
    self.load_balancer_arns = []
    self.tags = {}
    if matcher is None:
        self.matcher = {'HttpCode': '200'}
    else:
        self.matcher = matcher
    self.target_type = target_type
    self.attributes = {
        'deregistration_delay.timeout_seconds': 300,
        'stickiness.enabled': 'false',
    }
    self.targets = OrderedDict()

# Platform endpoint model __init__ (cf. PlatformEndpoint usage in the
# SNS backends below)
def __init__(self, region, application, custom_user_data, token, attributes):
    self.region = region
    self.application = application
    self.custom_user_data = custom_user_data
    self.token = token
    self.attributes = attributes
    self.id = uuid.uuid4()
    self.messages = OrderedDict()
    self.__fixup_attributes()

# SNSBackend.__init__, variant that also tracks sent SMS messages
def __init__(self, region_name):
    super(SNSBackend, self).__init__()
    self.topics = OrderedDict()
    self.subscriptions = OrderedDict()
    self.applications = {}
    self.platform_endpoints = {}
    self.region_name = region_name
    self.sms_attributes = {}
    self.sms_messages = OrderedDict()
    self.opt_out_numbers = [
        "+447420500600",
        "+447420505401",
        "+447632960543",
        "+447632960028",
        "+447700900149",
        "+447700900550",
        "+447700900545",
        "+447700900907",
    ]

def __init__(self, region, name, extended_config):
    self.region = region
    self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex))
    self.name = name
    self.status = None
    self.extended_config = extended_config or {}
    self.creation_date = datetime.datetime.utcnow()
    self.last_modified_date = datetime.datetime.utcnow()

    self.clients = OrderedDict()
    self.identity_providers = OrderedDict()
    self.users = OrderedDict()
    self.refresh_tokens = {}
    self.access_tokens = {}
    self.id_tokens = {}

    with open(os.path.join(os.path.dirname(__file__),
                           "resources/jwks-private.json")) as f:
        self.json_web_key = json.loads(f.read())

# Load balancer model __init__ (minimal variant)
def __init__(self, name, security_groups, subnets, vpc_id, arn, dns_name,
             scheme='internet-facing'):
    self.name = name
    self.created_time = datetime.datetime.now()
    self.scheme = scheme
    self.security_groups = security_groups
    self.subnets = subnets or []
    self.vpc_id = vpc_id
    self.listeners = OrderedDict()
    self.tags = {}
    self.arn = arn
    self.dns_name = dns_name

# FakeTargetGroup.__init__ with matcher and target_type, no defaults
def __init__(self, name, arn, vpc_id, protocol, port,
             healthcheck_protocol, healthcheck_port, healthcheck_path,
             healthcheck_interval_seconds, healthcheck_timeout_seconds,
             healthy_threshold_count, unhealthy_threshold_count,
             matcher=None, target_type=None):
    self.name = name
    self.arn = arn
    self.vpc_id = vpc_id
    self.protocol = protocol
    self.port = port
    self.healthcheck_protocol = healthcheck_protocol
    self.healthcheck_port = healthcheck_port
    self.healthcheck_path = healthcheck_path
    self.healthcheck_interval_seconds = healthcheck_interval_seconds
    self.healthcheck_timeout_seconds = healthcheck_timeout_seconds
    self.healthy_threshold_count = healthy_threshold_count
    self.unhealthy_threshold_count = unhealthy_threshold_count
    self.load_balancer_arns = []
    self.tags = {}
    self.matcher = matcher
    self.target_type = target_type
    self.attributes = {
        'deregistration_delay.timeout_seconds': 300,
        'stickiness.enabled': 'false',
    }
    self.targets = OrderedDict()

# RedshiftBackend.__init__ (variant without snapshot copy grants)
def __init__(self, ec2_backend, region_name):
    self.region = region_name
    self.clusters = {}
    self.subnet_groups = {}
    self.security_groups = {
        "Default": SecurityGroup("Default",
                                 "Default Redshift Security Group",
                                 self.region)
    }
    self.parameter_groups = {
        "default.redshift-1.0": ParameterGroup("default.redshift-1.0",
                                               "redshift-1.0",
                                               "Default Redshift parameter group",
                                               self.region)
    }
    self.ec2_backend = ec2_backend
    self.snapshots = OrderedDict()
    self.RESOURCE_TYPE_MAP = {
        'cluster': self.clusters,
        'parametergroup': self.parameter_groups,
        'securitygroup': self.security_groups,
        'snapshot': self.snapshots,
        'subnetgroup': self.subnet_groups
    }

# FakeTargetGroup.__init__ variant with ELBv2-style defaults; here the
# health-check port falls back to 'traffic-port' (see the sketch below)
def __init__(self, name, arn, vpc_id, protocol, port,
             healthcheck_protocol=None, healthcheck_port=None,
             healthcheck_path=None, healthcheck_interval_seconds=None,
             healthcheck_timeout_seconds=None, healthy_threshold_count=None,
             unhealthy_threshold_count=None, matcher=None, target_type=None):
    # TODO: default values differ when you add a Network Load Balancer
    self.name = name
    self.arn = arn
    self.vpc_id = vpc_id
    self.protocol = protocol
    self.port = port
    self.healthcheck_protocol = healthcheck_protocol or 'HTTP'
    self.healthcheck_port = healthcheck_port or 'traffic-port'
    self.healthcheck_path = healthcheck_path or '/'
    self.healthcheck_interval_seconds = healthcheck_interval_seconds or 30
    self.healthcheck_timeout_seconds = healthcheck_timeout_seconds or 5
    self.healthy_threshold_count = healthy_threshold_count or 5
    self.unhealthy_threshold_count = unhealthy_threshold_count or 2
    self.load_balancer_arns = []
    self.tags = {}
    if matcher is None:
        self.matcher = {'HttpCode': '200'}
    else:
        self.matcher = matcher
    self.target_type = target_type
    self.attributes = {
        'deregistration_delay.timeout_seconds': 300,
        'stickiness.enabled': 'false',
    }
    self.targets = OrderedDict()

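# Construction sketch for the defaulted __init__ above, assuming it
# belongs to FakeTargetGroup as in the class variants elsewhere in this
# section; the ARN and VPC id are placeholders.
tg = FakeTargetGroup(
    name="my-targets",
    arn="arn:aws:elasticloadbalancing:us-east-1:123456789012:"
        "targetgroup/my-targets/abc123",
    vpc_id="vpc-12345678",
    protocol="HTTP",
    port=80,
)
assert tg.healthcheck_port == 'traffic-port'
assert tg.matcher == {'HttpCode': '200'}
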
# Listener model __init__
def __init__(self, load_balancer_arn, arn, protocol, port, ssl_policy,
             certificate, default_actions):
    self.load_balancer_arn = load_balancer_arn
    self.arn = arn
    self.protocol = protocol.upper()
    self.port = port
    self.ssl_policy = ssl_policy
    self.certificate = certificate
    self.certificates = [certificate] if certificate is not None else []
    self.default_actions = default_actions
    self.rules = OrderedDict()

# Load balancer model __init__ with default attributes
def __init__(self, name, security_groups, subnets, vpc_id, arn, dns_name,
             scheme='internet-facing'):
    self.name = name
    self.created_time = datetime.datetime.now()
    self.scheme = scheme
    self.security_groups = security_groups
    self.subnets = subnets or []
    self.vpc_id = vpc_id
    self.listeners = OrderedDict()
    self.tags = {}
    self.arn = arn
    self.dns_name = dns_name
    self.stack = 'ipv4'
    self.attrs = {
        'access_logs.s3.enabled': 'false',
        'access_logs.s3.bucket': None,
        'access_logs.s3.prefix': None,
        'deletion_protection.enabled': 'false',
        'idle_timeout.timeout_seconds': '60'
    }

def transform(value, spec):
    """Apply transformations to make the output JSON comply with the
    expected form.

    This function applies:
    (1) Type cast to nodes with "type" property (e.g., 'true' to True).
        XML field values are all in text so this step is necessary to
        convert it to valid JSON objects.
    (2) Squashes "member" nodes to lists.
    """
    if len(spec) == 1:
        return from_str(value, spec)

    od = OrderedDict()
    for k, v in value.items():
        if k.startswith('@') or v is None:
            continue

        if spec[k]['type'] == 'list':
            if len(spec[k]['member']) == 1:
                if isinstance(v['member'], list):
                    od[k] = transform(v['member'], spec[k]['member'])
                else:
                    od[k] = [transform(v['member'], spec[k]['member'])]
            elif isinstance(v['member'], list):
                od[k] = [transform(o, spec[k]['member'])
                         for o in v['member']]
            elif isinstance(v['member'], OrderedDict):
                od[k] = [transform(v['member'], spec[k]['member'])]
            else:
                raise ValueError('Malformatted input')
        elif spec[k]['type'] == 'map':
            key = from_str(v['entry']['key'], spec[k]['key'])
            val = from_str(v['entry']['value'], spec[k]['value'])
            od[k] = {key: val}
        else:
            od[k] = transform(v, spec[k])
    return od

# DataPipelineBackend.__init__
def __init__(self):
    self.pipelines = OrderedDict()

class DynamoDBBackend(BaseBackend):

    def __init__(self):
        self.tables = OrderedDict()

    def create_table(self, name, **params):
        if name in self.tables:
            return None
        table = Table(name, **params)
        self.tables[name] = table
        return table

    def delete_table(self, name):
        return self.tables.pop(name, None)

    def update_table_throughput(self, name, throughput):
        table = self.tables[name]
        table.throughput = throughput
        return table

    def update_table_global_indexes(self, name, global_index_updates):
        table = self.tables[name]
        gsis_by_name = dict((i['IndexName'], i)
                            for i in table.global_indexes)
        for gsi_update in global_index_updates:
            gsi_to_create = gsi_update.get('Create')
            gsi_to_update = gsi_update.get('Update')
            gsi_to_delete = gsi_update.get('Delete')

            if gsi_to_delete:
                index_name = gsi_to_delete['IndexName']
                if index_name not in gsis_by_name:
                    raise ValueError(
                        'Global Secondary Index does not exist, but tried to '
                        'delete: %s' % gsi_to_delete['IndexName'])
                del gsis_by_name[index_name]

            if gsi_to_update:
                index_name = gsi_to_update['IndexName']
                if index_name not in gsis_by_name:
                    raise ValueError(
                        'Global Secondary Index does not exist, but tried to '
                        'update: %s' % gsi_to_update['IndexName'])
                gsis_by_name[index_name].update(gsi_to_update)

            if gsi_to_create:
                if gsi_to_create['IndexName'] in gsis_by_name:
                    raise ValueError(
                        'Global Secondary Index already exists: %s' %
                        gsi_to_create['IndexName'])
                gsis_by_name[gsi_to_create['IndexName']] = gsi_to_create

        table.global_indexes = gsis_by_name.values()
        return table

    def put_item(self, table_name, item_attrs, expected=None,
                 overwrite=False):
        table = self.tables.get(table_name)
        if not table:
            return None
        return table.put_item(item_attrs, expected, overwrite)

    def get_table_keys_name(self, table_name, keys):
        """
        Given a set of keys, extracts the key and range key
        """
        table = self.tables.get(table_name)
        if not table:
            return None, None
        else:
            hash_key = range_key = None
            for key in keys:
                if key in table.hash_key_names:
                    hash_key = key
                elif key in table.range_key_names:
                    range_key = key
            return hash_key, range_key

    def get_keys_value(self, table, keys):
        if table.hash_key_attr not in keys or (
                table.has_range_key and table.range_key_attr not in keys):
            raise ValueError(
                "Table has a range key, but no range key was passed into "
                "get_item")
        hash_key = DynamoType(keys[table.hash_key_attr])
        range_key = DynamoType(
            keys[table.range_key_attr]) if table.has_range_key else None
        return hash_key, range_key

    def get_table(self, table_name):
        return self.tables.get(table_name)

    def get_item(self, table_name, keys):
        table = self.get_table(table_name)
        if not table:
            raise ValueError("No table found")
        hash_key, range_key = self.get_keys_value(table, keys)
        return table.get_item(hash_key, range_key)

    def query(self, table_name, hash_key_dict, range_comparison,
              range_value_dicts, index_name=None):
        table = self.tables.get(table_name)
        if not table:
            return None, None
        hash_key = DynamoType(hash_key_dict)
        range_values = [DynamoType(range_value)
                        for range_value in range_value_dicts]
        return table.query(hash_key, range_comparison, range_values,
                           index_name)

    def scan(self, table_name, filters):
        table = self.tables.get(table_name)
        if not table:
            return None, None, None
        scan_filters = {}
        for key, (comparison_operator, comparison_values) in filters.items():
            dynamo_types = [DynamoType(value) for value in comparison_values]
            scan_filters[key] = (comparison_operator, dynamo_types)
        return table.scan(scan_filters)

    def update_item(self, table_name, key, update_expression,
                    attribute_updates):
        table = self.get_table(table_name)

        if all([table.hash_key_attr in key, table.range_key_attr in key]):
            # Covers cases where table has hash and range keys, ``key`` param
            # will be a dict
            hash_value = DynamoType(key[table.hash_key_attr])
            range_value = DynamoType(key[table.range_key_attr])
        elif table.hash_key_attr in key:
            # Covers tables that have a range key where ``key`` param is a dict
            hash_value = DynamoType(key[table.hash_key_attr])
            range_value = None
        else:
            # Covers other cases
            hash_value = DynamoType(key)
            range_value = None

        item = table.get_item(hash_value, range_value)
        if update_expression:
            item.update(update_expression)
        else:
            item.update_with_attribute_updates(attribute_updates)
        return item

    def delete_item(self, table_name, keys):
        table = self.tables.get(table_name)
        if not table:
            return None
        hash_key, range_key = self.get_keys_value(table, keys)
        return table.delete_item(hash_key, range_key)

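# Sketch of the scan() filter format above: each attribute name maps to a
# (comparison_operator, [values in DynamoDB wire format]) pair, which the
# backend wraps in DynamoType. With no table registered the call simply
# returns (None, None, None).
backend = DynamoDBBackend()
results = backend.scan(
    "my-table",  # placeholder name
    {"username": ("EQ", [{"S": "alice"}]), "age": ("GT", [{"N": "21"}])},
)
assert results == (None, None, None)
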
# DynamoDBBackend.__init__
def __init__(self):
    self.tables = OrderedDict()

# Region-aware DynamoDBBackend.__init__
def __init__(self, region_name=None):
    self.region_name = region_name
    self.tables = OrderedDict()

class CloudFormationBackend(BaseBackend):

    def __init__(self):
        self.stacks = OrderedDict()
        self.deleted_stacks = {}
        self.exports = OrderedDict()
        self.change_sets = OrderedDict()

    def create_stack(self, name, template, parameters, region_name,
                     notification_arns=None, tags=None, role_arn=None,
                     create_change_set=False):
        stack_id = generate_stack_id(name)
        new_stack = FakeStack(
            stack_id=stack_id,
            name=name,
            template=template,
            parameters=parameters,
            region_name=region_name,
            notification_arns=notification_arns,
            tags=tags,
            role_arn=role_arn,
            cross_stack_resources=self.exports,
            create_change_set=create_change_set,
        )
        self.stacks[stack_id] = new_stack
        self._validate_export_uniqueness(new_stack)
        for export in new_stack.exports:
            self.exports[export.name] = export
        return new_stack

    def create_change_set(self, stack_name, change_set_name, template,
                          parameters, region_name, change_set_type,
                          notification_arns=None, tags=None, role_arn=None):
        if change_set_type == 'UPDATE':
            stacks = self.stacks.values()
            stack = None
            for s in stacks:
                if s.name == stack_name:
                    stack = s
            if stack is None:
                raise ValidationError(stack_name)
        else:
            stack = self.create_stack(stack_name, template, parameters,
                                      region_name, notification_arns, tags,
                                      role_arn, create_change_set=True)
        change_set_id = generate_changeset_id(change_set_name, region_name)
        self.stacks[change_set_name] = {'Id': change_set_id,
                                        'StackId': stack.stack_id}
        self.change_sets[change_set_id] = stack
        return change_set_id, stack.stack_id

    def execute_change_set(self, change_set_name, stack_name=None):
        stack = None
        if change_set_name in self.change_sets:
            # This means arn was passed in
            stack = self.change_sets[change_set_name]
        else:
            for cs in self.change_sets:
                if self.change_sets[cs].name == change_set_name:
                    stack = self.change_sets[cs]
        if stack is None:
            raise ValidationError(stack_name)
        if stack.events[-1].resource_status == 'REVIEW_IN_PROGRESS':
            stack._add_stack_event('CREATE_COMPLETE')
        else:
            stack._add_stack_event('UPDATE_IN_PROGRESS')
            stack._add_stack_event('UPDATE_COMPLETE')
        return True

    def describe_stacks(self, name_or_stack_id):
        stacks = self.stacks.values()
        if name_or_stack_id:
            for stack in stacks:
                if stack.name == name_or_stack_id or \
                        stack.stack_id == name_or_stack_id:
                    return [stack]
            if self.deleted_stacks:
                deleted_stacks = self.deleted_stacks.values()
                for stack in deleted_stacks:
                    if stack.stack_id == name_or_stack_id:
                        return [stack]
            raise ValidationError(name_or_stack_id)
        else:
            return list(stacks)

    def list_stacks(self):
        return self.stacks.values()

    def get_stack(self, name_or_stack_id):
        all_stacks = dict(self.deleted_stacks, **self.stacks)
        if name_or_stack_id in all_stacks:
            # Lookup by stack id - deleted stacks included
            return all_stacks[name_or_stack_id]
        else:
            # Lookup by stack name - undeleted stacks only
            for stack in self.stacks.values():
                if stack.name == name_or_stack_id:
                    return stack

    def update_stack(self, name, template, role_arn=None, parameters=None,
                     tags=None):
        stack = self.get_stack(name)
        stack.update(template, role_arn, parameters=parameters, tags=tags)
        return stack

    def list_stack_resources(self, stack_name_or_id):
        stack = self.get_stack(stack_name_or_id)
        return stack.stack_resources

    def delete_stack(self, name_or_stack_id):
        if name_or_stack_id in self.stacks:
            # Delete by stack id
            stack = self.stacks.pop(name_or_stack_id, None)
            stack.delete()
            self.deleted_stacks[stack.stack_id] = stack
            [self.exports.pop(export.name) for export in stack.exports]
            return self.stacks.pop(name_or_stack_id, None)
        else:
            # Delete by stack name
            for stack in list(self.stacks.values()):
                if stack.name == name_or_stack_id:
                    self.delete_stack(stack.stack_id)

    def list_exports(self, token):
        all_exports = list(self.exports.values())
        if token is None:
            exports = all_exports[0:100]
            next_token = '100' if len(all_exports) > 100 else None
        else:
            token = int(token)
            exports = all_exports[token:token + 100]
            next_token = str(token + 100) if len(all_exports) > token + 100 \
                else None
        return exports, next_token

    def _validate_export_uniqueness(self, stack):
        new_stack_export_names = [x.name for x in stack.exports]
        export_names = self.exports.keys()
        if not set(export_names).isdisjoint(new_stack_export_names):
            raise ValidationError(
                stack.stack_id,
                message='Export names must be unique across a given region')

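# Paging sketch for list_exports above: pages of 100, with the next
# offset returned as a string token and None once exhausted.
backend = CloudFormationBackend()
exports, token = backend.list_exports(token=None)
while token is not None:
    page, token = backend.list_exports(token=token)
    exports.extend(page)
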
class RedshiftBackend(BaseBackend):

    def __init__(self, ec2_backend, region_name):
        self.region = region_name
        self.clusters = {}
        self.subnet_groups = {}
        self.security_groups = {
            "Default": SecurityGroup(
                "Default", "Default Redshift Security Group", self.region
            )
        }
        self.parameter_groups = {
            "default.redshift-1.0": ParameterGroup(
                "default.redshift-1.0",
                "redshift-1.0",
                "Default Redshift parameter group",
                self.region,
            )
        }
        self.ec2_backend = ec2_backend
        self.snapshots = OrderedDict()
        self.RESOURCE_TYPE_MAP = {
            "cluster": self.clusters,
            "parametergroup": self.parameter_groups,
            "securitygroup": self.security_groups,
            "snapshot": self.snapshots,
            "subnetgroup": self.subnet_groups,
        }
        self.snapshot_copy_grants = {}

    def reset(self):
        ec2_backend = self.ec2_backend
        region_name = self.region
        self.__dict__ = {}
        self.__init__(ec2_backend, region_name)

    def enable_snapshot_copy(self, **kwargs):
        cluster_identifier = kwargs["cluster_identifier"]
        cluster = self.clusters[cluster_identifier]
        if not hasattr(cluster, "cluster_snapshot_copy_status"):
            if (
                cluster.encrypted == "true"
                and kwargs["snapshot_copy_grant_name"] is None
            ):
                raise ClientError(
                    "InvalidParameterValue",
                    "SnapshotCopyGrantName is required for Snapshot Copy "
                    "on KMS encrypted clusters.",
                )
            status = {
                "DestinationRegion": kwargs["destination_region"],
                "RetentionPeriod": kwargs["retention_period"],
                "SnapshotCopyGrantName": kwargs["snapshot_copy_grant_name"],
            }
            cluster.cluster_snapshot_copy_status = status
            return cluster
        else:
            raise SnapshotCopyAlreadyEnabledFaultError(cluster_identifier)

    def disable_snapshot_copy(self, **kwargs):
        cluster_identifier = kwargs["cluster_identifier"]
        cluster = self.clusters[cluster_identifier]
        if hasattr(cluster, "cluster_snapshot_copy_status"):
            del cluster.cluster_snapshot_copy_status
            return cluster
        else:
            raise SnapshotCopyAlreadyDisabledFaultError(cluster_identifier)

    def modify_snapshot_copy_retention_period(
        self, cluster_identifier, retention_period
    ):
        cluster = self.clusters[cluster_identifier]
        if hasattr(cluster, "cluster_snapshot_copy_status"):
            cluster.cluster_snapshot_copy_status["RetentionPeriod"] = retention_period
            return cluster
        else:
            raise SnapshotCopyDisabledFaultError(cluster_identifier)

    def create_cluster(self, **cluster_kwargs):
        cluster_identifier = cluster_kwargs["cluster_identifier"]
        cluster = Cluster(self, **cluster_kwargs)
        self.clusters[cluster_identifier] = cluster
        return cluster

    def describe_clusters(self, cluster_identifier=None):
        clusters = self.clusters.values()
        if cluster_identifier:
            if cluster_identifier in self.clusters:
                return [self.clusters[cluster_identifier]]
            else:
                raise ClusterNotFoundError(cluster_identifier)
        return clusters

    def modify_cluster(self, **cluster_kwargs):
        cluster_identifier = cluster_kwargs.pop("cluster_identifier")
        new_cluster_identifier = cluster_kwargs.pop(
            "new_cluster_identifier", None)

        cluster = self.describe_clusters(cluster_identifier)[0]

        for key, value in cluster_kwargs.items():
            setattr(cluster, key, value)

        if new_cluster_identifier:
            dic = {
                "cluster_identifier": cluster_identifier,
                "skip_final_snapshot": True,
                "final_cluster_snapshot_identifier": None,
            }
            self.delete_cluster(**dic)
            cluster.cluster_identifier = new_cluster_identifier
            self.clusters[new_cluster_identifier] = cluster

        return cluster

    def delete_cluster(self, **cluster_kwargs):
        cluster_identifier = cluster_kwargs.pop("cluster_identifier")
        cluster_skip_final_snapshot = cluster_kwargs.pop("skip_final_snapshot")
        cluster_snapshot_identifier = cluster_kwargs.pop(
            "final_cluster_snapshot_identifier"
        )

        if cluster_identifier in self.clusters:
            if (
                cluster_skip_final_snapshot is False
                and cluster_snapshot_identifier is None
            ):
                raise ClientError(
                    "InvalidParameterValue",
                    "FinalSnapshotIdentifier is required for Snapshot copy "
                    "when SkipFinalSnapshot is False",
                )
            elif (
                cluster_skip_final_snapshot is False
                and cluster_snapshot_identifier is not None
            ):
                # create snapshot
                cluster = self.describe_clusters(cluster_identifier)[0]
                self.create_cluster_snapshot(
                    cluster_identifier,
                    cluster_snapshot_identifier,
                    cluster.region,
                    cluster.tags,
                )

            return self.clusters.pop(cluster_identifier)
        raise ClusterNotFoundError(cluster_identifier)

    def create_cluster_subnet_group(
        self, cluster_subnet_group_name, description, subnet_ids,
        region_name, tags=None
    ):
        subnet_group = SubnetGroup(
            self.ec2_backend,
            cluster_subnet_group_name,
            description,
            subnet_ids,
            region_name,
            tags,
        )
        self.subnet_groups[cluster_subnet_group_name] = subnet_group
        return subnet_group

    def describe_cluster_subnet_groups(self, subnet_identifier=None):
        subnet_groups = self.subnet_groups.values()
        if subnet_identifier:
            if subnet_identifier in self.subnet_groups:
                return [self.subnet_groups[subnet_identifier]]
            else:
                raise ClusterSubnetGroupNotFoundError(subnet_identifier)
        return subnet_groups

    def delete_cluster_subnet_group(self, subnet_identifier):
        if subnet_identifier in self.subnet_groups:
            return self.subnet_groups.pop(subnet_identifier)
        raise ClusterSubnetGroupNotFoundError(subnet_identifier)

    def create_cluster_security_group(
        self, cluster_security_group_name, description, region_name, tags=None
    ):
        security_group = SecurityGroup(
            cluster_security_group_name, description, region_name, tags
        )
        self.security_groups[cluster_security_group_name] = security_group
        return security_group

    def describe_cluster_security_groups(self, security_group_name=None):
        security_groups = self.security_groups.values()
        if security_group_name:
            if security_group_name in self.security_groups:
                return [self.security_groups[security_group_name]]
            else:
                raise ClusterSecurityGroupNotFoundError(security_group_name)
        return security_groups

    def delete_cluster_security_group(self, security_group_identifier):
        if security_group_identifier in self.security_groups:
            return self.security_groups.pop(security_group_identifier)
        raise ClusterSecurityGroupNotFoundError(security_group_identifier)

    def create_cluster_parameter_group(
        self,
        cluster_parameter_group_name,
        group_family,
        description,
        region_name,
        tags=None,
    ):
        parameter_group = ParameterGroup(
            cluster_parameter_group_name, group_family, description,
            region_name, tags
        )
        self.parameter_groups[cluster_parameter_group_name] = parameter_group
        return parameter_group

    def describe_cluster_parameter_groups(self, parameter_group_name=None):
        parameter_groups = self.parameter_groups.values()
        if parameter_group_name:
            if parameter_group_name in self.parameter_groups:
                return [self.parameter_groups[parameter_group_name]]
            else:
                raise ClusterParameterGroupNotFoundError(parameter_group_name)
        return parameter_groups

    def delete_cluster_parameter_group(self, parameter_group_name):
        if parameter_group_name in self.parameter_groups:
            return self.parameter_groups.pop(parameter_group_name)
        raise ClusterParameterGroupNotFoundError(parameter_group_name)

    def create_cluster_snapshot(
        self, cluster_identifier, snapshot_identifier, region_name, tags
    ):
        cluster = self.clusters.get(cluster_identifier)
        if not cluster:
            raise ClusterNotFoundError(cluster_identifier)
        if self.snapshots.get(snapshot_identifier) is not None:
            raise ClusterSnapshotAlreadyExistsError(snapshot_identifier)
        snapshot = Snapshot(cluster, snapshot_identifier, region_name, tags)
        self.snapshots[snapshot_identifier] = snapshot
        return snapshot

    def describe_cluster_snapshots(
        self, cluster_identifier=None, snapshot_identifier=None
    ):
        if cluster_identifier:
            cluster_snapshots = []
            for snapshot in self.snapshots.values():
                if snapshot.cluster.cluster_identifier == cluster_identifier:
                    cluster_snapshots.append(snapshot)
            if cluster_snapshots:
                return cluster_snapshots
            raise ClusterNotFoundError(cluster_identifier)

        if snapshot_identifier:
            if snapshot_identifier in self.snapshots:
                return [self.snapshots[snapshot_identifier]]
            raise ClusterSnapshotNotFoundError(snapshot_identifier)

        return self.snapshots.values()

    def delete_cluster_snapshot(self, snapshot_identifier):
        if snapshot_identifier not in self.snapshots:
            raise ClusterSnapshotNotFoundError(snapshot_identifier)

        deleted_snapshot = self.snapshots.pop(snapshot_identifier)
        deleted_snapshot.status = "deleted"
        return deleted_snapshot

    def restore_from_cluster_snapshot(self, **kwargs):
        snapshot_identifier = kwargs.pop("snapshot_identifier")
        snapshot = self.describe_cluster_snapshots(
            snapshot_identifier=snapshot_identifier
        )[0]
        create_kwargs = {
            "node_type": snapshot.cluster.node_type,
            "master_username": snapshot.cluster.master_username,
            "master_user_password": snapshot.cluster.master_user_password,
            "db_name": snapshot.cluster.db_name,
            "cluster_type": "multi-node"
            if snapshot.cluster.number_of_nodes > 1
            else "single-node",
            "availability_zone": snapshot.cluster.availability_zone,
            "port": snapshot.cluster.port,
            "cluster_version": snapshot.cluster.cluster_version,
            "number_of_nodes": snapshot.cluster.number_of_nodes,
            "encrypted": snapshot.cluster.encrypted,
            "tags": snapshot.cluster.tags,
            "restored_from_snapshot": True,
            "enhanced_vpc_routing": snapshot.cluster.enhanced_vpc_routing,
        }
        create_kwargs.update(kwargs)
        return self.create_cluster(**create_kwargs)

    def create_snapshot_copy_grant(self, **kwargs):
        snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"]
        kms_key_id = kwargs["kms_key_id"]
        if snapshot_copy_grant_name not in self.snapshot_copy_grants:
            snapshot_copy_grant = SnapshotCopyGrant(
                snapshot_copy_grant_name, kms_key_id
            )
            self.snapshot_copy_grants[snapshot_copy_grant_name] = snapshot_copy_grant
            return snapshot_copy_grant
        raise SnapshotCopyGrantAlreadyExistsFaultError(snapshot_copy_grant_name)

    def delete_snapshot_copy_grant(self, **kwargs):
        snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"]
        if snapshot_copy_grant_name in self.snapshot_copy_grants:
            return self.snapshot_copy_grants.pop(snapshot_copy_grant_name)
        raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name)

    def describe_snapshot_copy_grants(self, **kwargs):
        copy_grants = self.snapshot_copy_grants.values()
        snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"]
        if snapshot_copy_grant_name:
            if snapshot_copy_grant_name in self.snapshot_copy_grants:
                return [self.snapshot_copy_grants[snapshot_copy_grant_name]]
            else:
                raise SnapshotCopyGrantNotFoundFaultError(
                    snapshot_copy_grant_name)
        return copy_grants

    def _get_resource_from_arn(self, arn):
        try:
            arn_breakdown = arn.split(":")
            resource_type = arn_breakdown[5]
            if resource_type == "snapshot":
                resource_id = arn_breakdown[6].split("/")[1]
            else:
                resource_id = arn_breakdown[6]
        except IndexError:
            resource_type = resource_id = arn
        resources = self.RESOURCE_TYPE_MAP.get(resource_type)
        if resources is None:
            message = (
                "Tagging is not supported for this type of resource: '{0}' "
                "(the ARN is potentially malformed, please check the ARN "
                "documentation for more information)".format(resource_type)
            )
            raise ResourceNotFoundFaultError(message=message)
        try:
            resource = resources[resource_id]
        except KeyError:
            raise ResourceNotFoundFaultError(resource_type, resource_id)
        else:
            return resource

    @staticmethod
    def _describe_tags_for_resources(resources):
        tagged_resources = []
        for resource in resources:
            for tag in resource.tags:
                data = {
                    "ResourceName": resource.arn,
                    "ResourceType": resource.resource_type,
                    "Tag": {"Key": tag["Key"], "Value": tag["Value"]},
                }
                tagged_resources.append(data)
        return tagged_resources

    def _describe_tags_for_resource_type(self, resource_type):
        resources = self.RESOURCE_TYPE_MAP.get(resource_type)
        if not resources:
            raise ResourceNotFoundFaultError(resource_type=resource_type)
        return self._describe_tags_for_resources(resources.values())

    def _describe_tags_for_resource_name(self, resource_name):
        resource = self._get_resource_from_arn(resource_name)
        return self._describe_tags_for_resources([resource])

    def create_tags(self, resource_name, tags):
        resource = self._get_resource_from_arn(resource_name)
        resource.create_tags(tags)

    def describe_tags(self, resource_name, resource_type):
        if resource_name and resource_type:
            raise InvalidParameterValueError(
                "You cannot filter a list of resources using an Amazon "
                "Resource Name (ARN) and a resource type together in the "
                "same request. Retry the request using either an ARN or "
                "a resource type, but not both."
            )
        if resource_type:
            return self._describe_tags_for_resource_type(resource_type.lower())
        if resource_name:
            return self._describe_tags_for_resource_name(resource_name)
        # If name and type are not specified, return all tagged resources.
        # TODO: Implement aws marker pagination
        tagged_resources = []
        for resource_type in self.RESOURCE_TYPE_MAP:
            try:
                tagged_resources += self._describe_tags_for_resource_type(
                    resource_type)
            except ResourceNotFoundFaultError:
                pass
        return tagged_resources

    def delete_tags(self, resource_name, tag_keys):
        resource = self._get_resource_from_arn(resource_name)
        resource.delete_tags(tag_keys)

def transform(value, spec):
    """Apply transformations to make the output JSON comply with the
    expected form.

    This function applies:
    (1) Type cast to nodes with "type" property (e.g., 'true' to True).
        XML field values are all in text so this step is necessary to
        convert it to valid JSON objects.
    (2) Squashes "member" nodes to lists.
    """
    if len(spec) == 1:
        return from_str(value, spec)

    od = OrderedDict()
    for k, v in value.items():
        if k.startswith('@'):
            continue

        if k not in spec:
            # this can happen with an older version of botocore for which
            # the node in the XML template is not defined in the service spec
            log.warning(
                'Field %s is not defined by the botocore version in use', k)
            continue

        if spec[k]['type'] == 'list':
            if v is None:
                od[k] = []
            elif len(spec[k]['member']) == 1:
                if isinstance(v['member'], list):
                    od[k] = transform(v['member'], spec[k]['member'])
                else:
                    od[k] = [transform(v['member'], spec[k]['member'])]
            elif isinstance(v['member'], list):
                od[k] = [transform(o, spec[k]['member'])
                         for o in v['member']]
            elif isinstance(v['member'], OrderedDict):
                od[k] = [transform(v['member'], spec[k]['member'])]
            else:
                raise ValueError('Malformatted input')
        elif spec[k]['type'] == 'map':
            if v is None:
                od[k] = {}
            else:
                items = ([v['entry']] if not isinstance(v['entry'], list)
                         else v['entry'])
                for item in items:
                    key = from_str(item['key'], spec[k]['key'])
                    val = from_str(item['value'], spec[k]['value'])
                    if k not in od:
                        od[k] = {}
                    od[k][key] = val
        else:
            if v is None:
                od[k] = None
            else:
                od[k] = transform(v, spec[k])
    return od

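# Illustration of transform() with hypothetical shapes: `spec` mirrors a
# botocore output shape, `value` is the OrderedDict an xmltodict-style
# parser would produce for the response XML, and from_str (from the
# surrounding module) handles the leaf type casts.
spec = {
    'Marker': {'type': 'string'},
    'Clusters': {'type': 'list', 'member': {
        'ClusterIdentifier': {'type': 'string'},
        'NumberOfNodes': {'type': 'integer'},
    }},
}
value = OrderedDict([
    ('Marker', 'abc'),
    ('Clusters', OrderedDict([('member', [
        OrderedDict([('ClusterIdentifier', 'c1'), ('NumberOfNodes', '2')]),
    ])])),
])
transform(value, spec)
# -> {'Marker': 'abc',
#     'Clusters': [{'ClusterIdentifier': 'c1', 'NumberOfNodes': 2}]}
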
class SNSBackend(BaseBackend):

    def __init__(self, region_name):
        super(SNSBackend, self).__init__()
        self.topics = OrderedDict()
        self.subscriptions = OrderedDict()
        self.applications = {}
        self.platform_endpoints = {}
        self.region_name = region_name
        self.sms_attributes = {}
        self.opt_out_numbers = ['+447420500600', '+447420505401',
                                '+447632960543', '+447632960028',
                                '+447700900149', '+447700900550',
                                '+447700900545', '+447700900907']
        self.permissions = {}

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def update_sms_attributes(self, attrs):
        self.sms_attributes.update(attrs)

    def create_topic(self, name):
        fails_constraints = not re.match(r'^[a-zA-Z0-9_-]{1,256}$', name)
        if fails_constraints:
            raise InvalidParameterValue(
                "Topic names must be made up of only uppercase and "
                "lowercase ASCII letters, numbers, underscores, and "
                "hyphens, and must be between 1 and 256 characters long.")

        candidate_topic = Topic(name, self)
        if candidate_topic.arn in self.topics:
            return self.topics[candidate_topic.arn]
        else:
            self.topics[candidate_topic.arn] = candidate_topic
            return candidate_topic

    def _get_values_nexttoken(self, values_map, next_token=None):
        if next_token is None:
            next_token = 0
        next_token = int(next_token)
        values = list(values_map.values())[
            next_token: next_token + DEFAULT_PAGE_SIZE]
        if len(values) == DEFAULT_PAGE_SIZE:
            next_token = next_token + DEFAULT_PAGE_SIZE
        else:
            next_token = None
        return values, next_token

    def _get_topic_subscriptions(self, topic):
        return [sub for sub in self.subscriptions.values()
                if sub.topic == topic]

    def list_topics(self, next_token=None):
        return self._get_values_nexttoken(self.topics, next_token)

    def delete_topic(self, arn):
        topic = self.get_topic(arn)
        subscriptions = self._get_topic_subscriptions(topic)
        for sub in subscriptions:
            self.unsubscribe(sub.arn)
        self.topics.pop(arn)

    def get_topic(self, arn):
        try:
            return self.topics[arn]
        except KeyError:
            raise SNSNotFoundError("Topic with arn {0} not found".format(arn))

    def get_topic_from_phone_number(self, number):
        for subscription in self.subscriptions.values():
            if subscription.protocol == 'sms' and \
                    subscription.endpoint == number:
                return subscription.topic.arn
        raise SNSNotFoundError('Could not find valid subscription')

    def set_topic_attribute(self, topic_arn, attribute_name, attribute_value):
        topic = self.get_topic(topic_arn)
        setattr(topic, attribute_name, attribute_value)

    def subscribe(self, topic_arn, endpoint, protocol):
        # AWS doesn't create duplicates
        old_subscription = self._find_subscription(topic_arn, endpoint,
                                                   protocol)
        if old_subscription:
            return old_subscription
        topic = self.get_topic(topic_arn)
        subscription = Subscription(topic, endpoint, protocol)
        self.subscriptions[subscription.arn] = subscription
        return subscription

    def _find_subscription(self, topic_arn, endpoint, protocol):
        for subscription in self.subscriptions.values():
            if subscription.topic.arn == topic_arn and \
                    subscription.endpoint == endpoint and \
                    subscription.protocol == protocol:
                return subscription
        return None

    def unsubscribe(self, subscription_arn):
        self.subscriptions.pop(subscription_arn)

    def list_subscriptions(self, topic_arn=None, next_token=None):
        if topic_arn:
            topic = self.get_topic(topic_arn)
            filtered = OrderedDict(
                [(sub.arn, sub)
                 for sub in self._get_topic_subscriptions(topic)])
            return self._get_values_nexttoken(filtered, next_token)
        else:
            return self._get_values_nexttoken(self.subscriptions, next_token)

    def publish(self, arn, message, subject=None, message_attributes=None):
        if subject is not None and len(subject) > 100:
            # Note that the AWS docs around length are wrong:
            # https://github.com/spulec/moto/issues/1503
            raise ValueError('Subject must be less than 100 characters')

        if len(message) > MAXIMUM_MESSAGE_LENGTH:
            raise InvalidParameterValue(
                "An error occurred (InvalidParameter) when calling the "
                "Publish operation: Invalid parameter: Message too long")

        try:
            topic = self.get_topic(arn)
            message_id = topic.publish(message, subject=subject,
                                       message_attributes=message_attributes)
        except SNSNotFoundError:
            endpoint = self.get_endpoint(arn)
            message_id = endpoint.publish(message)
        return message_id

    def create_platform_application(self, region, name, platform, attributes):
        application = PlatformApplication(region, name, platform, attributes)
        self.applications[application.arn] = application
        return application

    def get_application(self, arn):
        try:
            return self.applications[arn]
        except KeyError:
            raise SNSNotFoundError(
                "Application with arn {0} not found".format(arn))

    def set_application_attributes(self, arn, attributes):
        application = self.get_application(arn)
        application.attributes.update(attributes)
        return application

    def list_platform_applications(self):
        return self.applications.values()

    def delete_platform_application(self, platform_arn):
        self.applications.pop(platform_arn)

    def create_platform_endpoint(self, region, application, custom_user_data,
                                 token, attributes):
        if any(token == endpoint.token
               for endpoint in self.platform_endpoints.values()):
            raise DuplicateSnsEndpointError(
                "Duplicate endpoint token: %s" % token)
        platform_endpoint = PlatformEndpoint(
            region, application, custom_user_data, token, attributes)
        self.platform_endpoints[platform_endpoint.arn] = platform_endpoint
        return platform_endpoint

    def list_endpoints_by_platform_application(self, application_arn):
        return [
            endpoint for endpoint in self.platform_endpoints.values()
            if endpoint.application.arn == application_arn
        ]

    def get_endpoint(self, arn):
        try:
            return self.platform_endpoints[arn]
        except KeyError:
            raise SNSNotFoundError(
                "Endpoint with arn {0} not found".format(arn))

    def set_endpoint_attributes(self, arn, attributes):
        endpoint = self.get_endpoint(arn)
        endpoint.attributes.update(attributes)
        return endpoint

    def delete_endpoint(self, arn):
        try:
            del self.platform_endpoints[arn]
        except KeyError:
            raise SNSNotFoundError(
                "Endpoint with arn {0} not found".format(arn))

    def get_subscription_attributes(self, arn):
        _subscription = [_ for _ in self.subscriptions.values()
                         if _.arn == arn]
        if not _subscription:
            raise SNSNotFoundError(
                "Subscription with arn {0} not found".format(arn))
        subscription = _subscription[0]

        return subscription.attributes

    def set_subscription_attributes(self, arn, name, value):
        if name not in ['RawMessageDelivery', 'DeliveryPolicy',
                        'FilterPolicy']:
            raise SNSInvalidParameter('AttributeName')

        # TODO: should do validation
        _subscription = [_ for _ in self.subscriptions.values()
                         if _.arn == arn]
        if not _subscription:
            raise SNSNotFoundError(
                "Subscription with arn {0} not found".format(arn))
        subscription = _subscription[0]

        subscription.attributes[name] = value

        if name == 'FilterPolicy':
            subscription._filter_policy = json.loads(value)

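# Usage sketch for the full SNSBackend above. Topic, Subscription, and
# the module constants come from the surrounding file; the region and
# phone number are placeholders.
backend = SNSBackend('us-east-1')
topic = backend.create_topic('alerts')
sub = backend.subscribe(topic.arn, '+15551234567', 'sms')
backend.set_subscription_attributes(
    sub.arn, 'FilterPolicy', '{"event": ["order_placed"]}')
backend.publish(topic.arn, 'hi', subject='greeting')
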
# CognitoIdpBackend.__init__
def __init__(self, region):
    super(CognitoIdpBackend, self).__init__()
    self.region = region
    self.user_pools = OrderedDict()
    self.user_pool_domains = OrderedDict()
    self.sessions = {}

# Region-aware load-balancer backend __init__
def __init__(self, region_name=None):
    self.region_name = region_name
    self.load_balancers = OrderedDict()

class FakeTargetGroup(BaseModel):

    def __init__(self, name, arn, vpc_id, protocol, port,
                 healthcheck_protocol, healthcheck_port, healthcheck_path,
                 healthcheck_interval_seconds, healthcheck_timeout_seconds,
                 healthy_threshold_count, unhealthy_threshold_count,
                 matcher=None, target_type=None):
        self.name = name
        self.arn = arn
        self.vpc_id = vpc_id
        self.protocol = protocol
        self.port = port
        self.healthcheck_protocol = healthcheck_protocol
        self.healthcheck_port = healthcheck_port
        self.healthcheck_path = healthcheck_path
        self.healthcheck_interval_seconds = healthcheck_interval_seconds
        self.healthcheck_timeout_seconds = healthcheck_timeout_seconds
        self.healthy_threshold_count = healthy_threshold_count
        self.unhealthy_threshold_count = unhealthy_threshold_count
        self.load_balancer_arns = []
        self.tags = {}
        self.matcher = matcher
        self.target_type = target_type
        self.attributes = {
            'deregistration_delay.timeout_seconds': 300,
            'stickiness.enabled': 'false',
        }
        self.targets = OrderedDict()

    @property
    def physical_resource_id(self):
        return self.arn

    def register(self, targets):
        for target in targets:
            self.targets[target['id']] = {
                'id': target['id'],
                'port': target.get('port', self.port),
            }

    def deregister(self, targets):
        for target in targets:
            t = self.targets.pop(target['id'], None)
            if not t:
                raise InvalidTargetError()

    def add_tag(self, key, value):
        if len(self.tags) >= 10 and key not in self.tags:
            raise TooManyTagsError()
        self.tags[key] = value

    def health_for(self, target):
        t = self.targets.get(target['id'])
        if t is None:
            raise InvalidTargetError()
        return FakeHealthStatus(t['id'], t['port'], self.healthcheck_port,
                                'healthy')

    @classmethod
    def create_from_cloudformation_json(cls, resource_name,
                                        cloudformation_json, region_name):
        properties = cloudformation_json['Properties']

        elbv2_backend = elbv2_backends[region_name]

        # per cloudformation docs:
        # The target group name should be shorter than 22 characters because
        # AWS CloudFormation uses the target group name to create the name
        # of the load balancer.
        name = properties.get('Name', resource_name[:22])
        vpc_id = properties.get("VpcId")
        protocol = properties.get('Protocol')
        port = properties.get("Port")
        healthcheck_protocol = properties.get("HealthCheckProtocol")
        healthcheck_port = properties.get("HealthCheckPort")
        healthcheck_path = properties.get("HealthCheckPath")
        healthcheck_interval_seconds = properties.get(
            "HealthCheckIntervalSeconds")
        healthcheck_timeout_seconds = properties.get(
            "HealthCheckTimeoutSeconds")
        healthy_threshold_count = properties.get("HealthyThresholdCount")
        unhealthy_threshold_count = properties.get("UnhealthyThresholdCount")
        matcher = properties.get("Matcher")
        target_type = properties.get("TargetType")

        target_group = elbv2_backend.create_target_group(
            name=name,
            vpc_id=vpc_id,
            protocol=protocol,
            port=port,
            healthcheck_protocol=healthcheck_protocol,
            healthcheck_port=healthcheck_port,
            healthcheck_path=healthcheck_path,
            healthcheck_interval_seconds=healthcheck_interval_seconds,
            healthcheck_timeout_seconds=healthcheck_timeout_seconds,
            healthy_threshold_count=healthy_threshold_count,
            unhealthy_threshold_count=unhealthy_threshold_count,
            matcher=matcher,
            target_type=target_type,
        )
        return target_group

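# Sketch of the Properties block create_from_cloudformation_json reads;
# values are placeholders, and elbv2_backends must hold a backend for the
# region for create_target_group to succeed.
cloudformation_json = {
    'Properties': {
        'Name': 'my-targets',
        'VpcId': 'vpc-12345678',
        'Protocol': 'HTTP',
        'Port': 80,
        'HealthCheckPath': '/health',
    },
}
target_group = FakeTargetGroup.create_from_cloudformation_json(
    'mytargetgroupresource', cloudformation_json, 'us-east-1')
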
# CloudFormationBackend.__init__
def __init__(self):
    self.stacks = OrderedDict()
    self.deleted_stacks = {}
    self.exports = OrderedDict()
    self.change_sets = OrderedDict()

class DynamoDBBackend(BaseBackend):

    def __init__(self):
        self.tables = OrderedDict()

    def create_table(self, name, **params):
        table = Table(name, **params)
        self.tables[name] = table
        return table

    def delete_table(self, name):
        return self.tables.pop(name, None)

    def update_table_throughput(self, name, new_read_units, new_write_units):
        table = self.tables[name]
        table.read_capacity = new_read_units
        table.write_capacity = new_write_units
        return table

    def put_item(self, table_name, item_attrs):
        table = self.tables.get(table_name)
        if not table:
            return None
        return table.put_item(item_attrs)

    def get_item(self, table_name, hash_key_dict, range_key_dict):
        table = self.tables.get(table_name)
        if not table:
            return None
        hash_key = DynamoType(hash_key_dict)
        range_key = DynamoType(range_key_dict) if range_key_dict else None
        return table.get_item(hash_key, range_key)

    def query(self, table_name, hash_key_dict, range_comparison,
              range_value_dicts):
        table = self.tables.get(table_name)
        if not table:
            return None, None
        hash_key = DynamoType(hash_key_dict)
        range_values = [DynamoType(range_value)
                        for range_value in range_value_dicts]
        return table.query(hash_key, range_comparison, range_values)

    def scan(self, table_name, filters):
        table = self.tables.get(table_name)
        if not table:
            return None, None, None
        scan_filters = {}
        for key, (comparison_operator, comparison_values) in filters.items():
            dynamo_types = [DynamoType(value) for value in comparison_values]
            scan_filters[key] = (comparison_operator, dynamo_types)
        return table.scan(scan_filters)

    def delete_item(self, table_name, hash_key_dict, range_key_dict):
        table = self.tables.get(table_name)
        if not table:
            return None
        hash_key = DynamoType(hash_key_dict)
        range_key = DynamoType(range_key_dict) if range_key_dict else None
        return table.delete_item(hash_key, range_key)

class SNSBackend(BaseBackend):

    def __init__(self, region_name):
        super(SNSBackend, self).__init__()
        self.topics = OrderedDict()
        self.subscriptions = OrderedDict()
        self.applications = {}
        self.platform_endpoints = {}
        self.region_name = region_name

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def create_topic(self, name):
        topic = Topic(name, self)
        self.topics[topic.arn] = topic
        return topic

    def _get_values_nexttoken(self, values_map, next_token=None):
        if next_token is None:
            next_token = 0
        next_token = int(next_token)
        values = list(values_map.values())[
            next_token:next_token + DEFAULT_PAGE_SIZE]
        if len(values) == DEFAULT_PAGE_SIZE:
            next_token = next_token + DEFAULT_PAGE_SIZE
        else:
            next_token = None
        return values, next_token

    def _get_topic_subscriptions(self, topic):
        return [sub for sub in self.subscriptions.values()
                if sub.topic == topic]

    def list_topics(self, next_token=None):
        return self._get_values_nexttoken(self.topics, next_token)

    def delete_topic(self, arn):
        topic = self.get_topic(arn)
        subscriptions = self._get_topic_subscriptions(topic)
        for sub in subscriptions:
            self.unsubscribe(sub.arn)
        self.topics.pop(arn)

    def get_topic(self, arn):
        try:
            return self.topics[arn]
        except KeyError:
            raise SNSNotFoundError("Topic with arn {0} not found".format(arn))

    def set_topic_attribute(self, topic_arn, attribute_name, attribute_value):
        topic = self.get_topic(topic_arn)
        setattr(topic, attribute_name, attribute_value)

    def subscribe(self, topic_arn, endpoint, protocol):
        topic = self.get_topic(topic_arn)
        subscription = Subscription(topic, endpoint, protocol)
        self.subscriptions[subscription.arn] = subscription
        return subscription

    def unsubscribe(self, subscription_arn):
        self.subscriptions.pop(subscription_arn)

    def list_subscriptions(self, topic_arn=None, next_token=None):
        if topic_arn:
            topic = self.get_topic(topic_arn)
            filtered = OrderedDict([
                (sub.arn, sub)
                for sub in self._get_topic_subscriptions(topic)
            ])
            return self._get_values_nexttoken(filtered, next_token)
        else:
            return self._get_values_nexttoken(self.subscriptions, next_token)

    def publish(self, arn, message):
        try:
            topic = self.get_topic(arn)
            message_id = topic.publish(message)
        except SNSNotFoundError:
            endpoint = self.get_endpoint(arn)
            message_id = endpoint.publish(message)
        return message_id

    def create_platform_application(self, region, name, platform, attributes):
        application = PlatformApplication(region, name, platform, attributes)
        self.applications[application.arn] = application
        return application

    def get_application(self, arn):
        try:
            return self.applications[arn]
        except KeyError:
            raise SNSNotFoundError(
                "Application with arn {0} not found".format(arn))

    def set_application_attributes(self, arn, attributes):
        application = self.get_application(arn)
        application.attributes.update(attributes)
        return application

    def list_platform_applications(self):
        return self.applications.values()

    def delete_platform_application(self, platform_arn):
        self.applications.pop(platform_arn)

    def create_platform_endpoint(self, region, application, custom_user_data,
                                 token, attributes):
        if any(token == endpoint.token
               for endpoint in self.platform_endpoints.values()):
            raise DuplicateSnsEndpointError(
                "Duplicate endpoint token: %s" % token)
        platform_endpoint = PlatformEndpoint(
            region, application, custom_user_data, token, attributes)
        self.platform_endpoints[platform_endpoint.arn] = platform_endpoint
        return platform_endpoint

    def list_endpoints_by_platform_application(self, application_arn):
        return [
            endpoint for endpoint in self.platform_endpoints.values()
            if endpoint.application.arn == application_arn
        ]

    def get_endpoint(self, arn):
        try:
            return self.platform_endpoints[arn]
        except KeyError:
            raise SNSNotFoundError(
                "Endpoint with arn {0} not found".format(arn))

    def set_endpoint_attributes(self, arn, attributes):
        endpoint = self.get_endpoint(arn)
        endpoint.attributes.update(attributes)
        return endpoint

    def delete_endpoint(self, arn):
        try:
            del self.platform_endpoints[arn]
        except KeyError:
            raise SNSNotFoundError(
                "Endpoint with arn {0} not found".format(arn))

class AutoScalingBackend(BaseBackend): def __init__(self, ec2_backend, elb_backend, elbv2_backend): self.autoscaling_groups = OrderedDict() self.launch_configurations = OrderedDict() self.policies = {} self.ec2_backend = ec2_backend self.elb_backend = elb_backend self.elbv2_backend = elbv2_backend def reset(self): ec2_backend = self.ec2_backend elb_backend = self.elb_backend elbv2_backend = self.elbv2_backend self.__dict__ = {} self.__init__(ec2_backend, elb_backend, elbv2_backend) def create_launch_configuration( self, name, image_id, key_name, kernel_id, ramdisk_id, security_groups, user_data, instance_type, instance_monitoring, instance_profile_name, spot_price, ebs_optimized, associate_public_ip_address, block_device_mappings, ): launch_configuration = FakeLaunchConfiguration( name=name, image_id=image_id, key_name=key_name, kernel_id=kernel_id, ramdisk_id=ramdisk_id, security_groups=security_groups, user_data=user_data, instance_type=instance_type, instance_monitoring=instance_monitoring, instance_profile_name=instance_profile_name, spot_price=spot_price, ebs_optimized=ebs_optimized, associate_public_ip_address=associate_public_ip_address, block_device_mapping_dict=block_device_mappings, ) self.launch_configurations[name] = launch_configuration return launch_configuration def describe_launch_configurations(self, names): configurations = self.launch_configurations.values() if names: return [ configuration for configuration in configurations if configuration.name in names ] else: return list(configurations) def delete_launch_configuration(self, launch_configuration_name): self.launch_configurations.pop(launch_configuration_name, None) def create_auto_scaling_group( self, name, availability_zones, desired_capacity, max_size, min_size, launch_config_name, vpc_zone_identifier, default_cooldown, health_check_period, health_check_type, load_balancers, target_group_arns, placement_group, termination_policies, tags, new_instances_protected_from_scale_in=False, instance_id=None, ): def make_int(value): return int(value) if value is not None else value max_size = make_int(max_size) min_size = make_int(min_size) desired_capacity = make_int(desired_capacity) default_cooldown = make_int(default_cooldown) if health_check_period is None: health_check_period = 300 else: health_check_period = make_int(health_check_period) if launch_config_name is None and instance_id is not None: try: instance = self.ec2_backend.get_instance(instance_id) launch_config_name = name FakeLaunchConfiguration.create_from_instance( launch_config_name, instance, self) except InvalidInstanceIdError: raise InvalidInstanceError(instance_id) group = FakeAutoScalingGroup( name=name, availability_zones=availability_zones, desired_capacity=desired_capacity, max_size=max_size, min_size=min_size, launch_config_name=launch_config_name, vpc_zone_identifier=vpc_zone_identifier, default_cooldown=default_cooldown, health_check_period=health_check_period, health_check_type=health_check_type, load_balancers=load_balancers, target_group_arns=target_group_arns, placement_group=placement_group, termination_policies=termination_policies, autoscaling_backend=self, tags=tags, new_instances_protected_from_scale_in= new_instances_protected_from_scale_in, ) self.autoscaling_groups[name] = group self.update_attached_elbs(group.name) self.update_attached_target_groups(group.name) return group def update_auto_scaling_group( self, name, availability_zones, desired_capacity, max_size, min_size, launch_config_name, vpc_zone_identifier, default_cooldown, 
health_check_period, health_check_type, placement_group, termination_policies, new_instances_protected_from_scale_in=None, ): group = self.autoscaling_groups[name] group.update( availability_zones, desired_capacity, max_size, min_size, launch_config_name, vpc_zone_identifier, default_cooldown, health_check_period, health_check_type, placement_group, termination_policies, new_instances_protected_from_scale_in= new_instances_protected_from_scale_in, ) return group def describe_auto_scaling_groups(self, names): groups = self.autoscaling_groups.values() if names: return [group for group in groups if group.name in names] else: return list(groups) def delete_auto_scaling_group(self, group_name): self.set_desired_capacity(group_name, 0) self.autoscaling_groups.pop(group_name, None) def describe_auto_scaling_instances(self, instance_ids): instance_states = [] for group in self.autoscaling_groups.values(): instance_states.extend([ x for x in group.instance_states if not instance_ids or x.instance.id in instance_ids ]) return instance_states def attach_instances(self, group_name, instance_ids): group = self.autoscaling_groups[group_name] original_size = len(group.instance_states) if (original_size + len(instance_ids)) > group.max_size: raise ResourceContentionError else: group.desired_capacity = original_size + len(instance_ids) new_instances = [ InstanceState( self.ec2_backend.get_instance(x), protected_from_scale_in=group. new_instances_protected_from_scale_in, ) for x in instance_ids ] for instance in new_instances: self.ec2_backend.create_tags([instance.instance.id], {ASG_NAME_TAG: group.name}) group.instance_states.extend(new_instances) self.update_attached_elbs(group.name) self.update_attached_target_groups(group.name) def set_instance_health(self, instance_id, health_status, should_respect_grace_period): instance = self.ec2_backend.get_instance(instance_id) instance_state = next(instance_state for group in self.autoscaling_groups.values() for instance_state in group.instance_states if instance_state.instance.id == instance.id) instance_state.health_status = health_status def detach_instances(self, group_name, instance_ids, should_decrement): group = self.autoscaling_groups[group_name] original_size = group.desired_capacity detached_instances = [ x for x in group.instance_states if x.instance.id in instance_ids ] for instance in detached_instances: self.ec2_backend.delete_tags([instance.instance.id], {ASG_NAME_TAG: group.name}) new_instance_state = [ x for x in group.instance_states if x.instance.id not in instance_ids ] group.instance_states = new_instance_state if should_decrement: group.desired_capacity = original_size - len(instance_ids) group.set_desired_capacity(group.desired_capacity) return detached_instances def set_desired_capacity(self, group_name, desired_capacity): group = self.autoscaling_groups[group_name] group.set_desired_capacity(desired_capacity) self.update_attached_elbs(group_name) def change_capacity(self, group_name, scaling_adjustment): group = self.autoscaling_groups[group_name] desired_capacity = group.desired_capacity + scaling_adjustment self.set_desired_capacity(group_name, desired_capacity) def change_capacity_percent(self, group_name, scaling_adjustment): """ http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/as-scale-based-on-demand.html If PercentChangeInCapacity returns a value between 0 and 1, Auto Scaling will round it off to 1. If the PercentChangeInCapacity returns a value greater than 1, Auto Scaling will round it off to the lower value. 
For example, if PercentChangeInCapacity returns 12.5, then Auto Scaling will round it off to 12.""" group = self.autoscaling_groups[group_name] percent_change = 1 + (scaling_adjustment / 100.0) desired_capacity = group.desired_capacity * percent_change if group.desired_capacity < desired_capacity < group.desired_capacity + 1: desired_capacity = group.desired_capacity + 1 else: desired_capacity = int(desired_capacity) self.set_desired_capacity(group_name, desired_capacity) def create_autoscaling_policy(self, name, policy_type, adjustment_type, as_name, scaling_adjustment, cooldown): policy = FakeScalingPolicy( name, policy_type, adjustment_type, as_name, scaling_adjustment, cooldown, self, ) self.policies[name] = policy return policy def describe_policies(self, autoscaling_group_name=None, policy_names=None, policy_types=None): return [ policy for policy in self.policies.values() if (not autoscaling_group_name or policy.as_name == autoscaling_group_name) and ( not policy_names or policy.name in policy_names) and ( not policy_types or policy.policy_type in policy_types) ] def delete_policy(self, group_name): self.policies.pop(group_name, None) def execute_policy(self, group_name): policy = self.policies[group_name] policy.execute() def update_attached_elbs(self, group_name): group = self.autoscaling_groups[group_name] group_instance_ids = set(state.instance.id for state in group.active_instances()) # skip this if group.load_balancers is empty # otherwise elb_backend.describe_load_balancers returns all available load balancers if not group.load_balancers: return try: elbs = self.elb_backend.describe_load_balancers( names=group.load_balancers) except LoadBalancerNotFoundError: # ELBs can be deleted before their autoscaling group return for elb in elbs: elb_instance_ids = set(elb.instance_ids) self.elb_backend.register_instances( elb.name, group_instance_ids - elb_instance_ids) self.elb_backend.deregister_instances( elb.name, elb_instance_ids - group_instance_ids) def update_attached_target_groups(self, group_name): group = self.autoscaling_groups[group_name] group_instance_ids = set(state.instance.id for state in group.instance_states) # no action necessary if target_group_arns is empty if not group.target_group_arns: return target_groups = self.elbv2_backend.describe_target_groups( target_group_arns=group.target_group_arns, load_balancer_arn=None, names=None, ) for target_group in target_groups: asg_targets = [{ "id": x, "port": target_group.port } for x in group_instance_ids] self.elbv2_backend.register_targets(target_group.arn, asg_targets) def create_or_update_tags(self, tags): for tag in tags: group_name = tag["resource_id"] group = self.autoscaling_groups[group_name] old_tags = group.tags new_tags = [] # if key was in old_tags, update old tag for old_tag in old_tags: if old_tag["key"] == tag["key"]: new_tags.append(tag) else: new_tags.append(old_tag) # if key was never in old_tags, add it (create tag) if not any(new_tag["key"] == tag["key"] for new_tag in new_tags): new_tags.append(tag) group.tags = new_tags def attach_load_balancers(self, group_name, load_balancer_names): group = self.autoscaling_groups[group_name] group.load_balancers.extend( [x for x in load_balancer_names if x not in group.load_balancers]) self.update_attached_elbs(group_name) def describe_load_balancers(self, group_name): return self.autoscaling_groups[group_name].load_balancers def detach_load_balancers(self, group_name, load_balancer_names): group = self.autoscaling_groups[group_name] group_instance_ids = 
set(state.instance.id for state in group.instance_states) elbs = self.elb_backend.describe_load_balancers( names=group.load_balancers) for elb in elbs: self.elb_backend.deregister_instances(elb.name, group_instance_ids) group.load_balancers = [ x for x in group.load_balancers if x not in load_balancer_names ] def attach_load_balancer_target_groups(self, group_name, target_group_arns): group = self.autoscaling_groups[group_name] group.append_target_groups(target_group_arns) self.update_attached_target_groups(group_name) def describe_load_balancer_target_groups(self, group_name): return self.autoscaling_groups[group_name].target_group_arns def detach_load_balancer_target_groups(self, group_name, target_group_arns): group = self.autoscaling_groups[group_name] group.target_group_arns = [ x for x in group.target_group_arns if x not in target_group_arns ] for target_group in target_group_arns: asg_targets = [{ "id": x.instance.id } for x in group.instance_states] self.elbv2_backend.deregister_targets(target_group, (asg_targets)) def suspend_processes(self, group_name, scaling_processes): group = self.autoscaling_groups[group_name] group.suspended_processes = scaling_processes or [] def set_instance_protection(self, group_name, instance_ids, protected_from_scale_in): group = self.autoscaling_groups[group_name] protected_instances = [ x for x in group.instance_states if x.instance.id in instance_ids ] for instance in protected_instances: instance.protected_from_scale_in = protected_from_scale_in def notify_terminate_instances(self, instance_ids): for ( autoscaling_group_name, autoscaling_group, ) in self.autoscaling_groups.items(): original_active_instance_count = len( autoscaling_group.active_instances()) autoscaling_group.instance_states = list( filter( lambda i_state: i_state.instance.id not in instance_ids, autoscaling_group.instance_states, )) difference = original_active_instance_count - len( autoscaling_group.active_instances()) if difference > 0: autoscaling_group.replace_autoscaling_group_instances( difference, autoscaling_group.get_propagated_tags()) self.update_attached_elbs(autoscaling_group_name) def enter_standby_instances(self, group_name, instance_ids, should_decrement): group = self.autoscaling_groups[group_name] original_size = group.desired_capacity standby_instances = [] for instance_state in group.instance_states: if instance_state.instance.id in instance_ids: instance_state.lifecycle_state = "Standby" standby_instances.append(instance_state) if should_decrement: group.desired_capacity = group.desired_capacity - len(instance_ids) group.set_desired_capacity(group.desired_capacity) return standby_instances, original_size, group.desired_capacity def exit_standby_instances(self, group_name, instance_ids): group = self.autoscaling_groups[group_name] original_size = group.desired_capacity standby_instances = [] for instance_state in group.instance_states: if instance_state.instance.id in instance_ids: instance_state.lifecycle_state = "InService" standby_instances.append(instance_state) group.desired_capacity = group.desired_capacity + len(instance_ids) group.set_desired_capacity(group.desired_capacity) return standby_instances, original_size, group.desired_capacity def terminate_instance(self, instance_id, should_decrement): instance = self.ec2_backend.get_instance(instance_id) instance_state = next(instance_state for group in self.autoscaling_groups.values() for instance_state in group.instance_states if instance_state.instance.id == instance.id) group = instance.autoscaling_group 
original_size = group.desired_capacity self.detach_instances(group.name, [instance.id], should_decrement) self.ec2_backend.terminate_instances([instance.id]) return instance_state, original_size, group.desired_capacity
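# Worked example of the change_capacity_percent rounding rule quoted in its
# docstring, pulled out into a standalone helper (hypothetical name) so the
# two branches are easy to check by hand:
def _percent_adjusted_capacity(current, percent_adjustment):
    desired = current * (1 + percent_adjustment / 100.0)
    if current < desired < current + 1:
        # e.g. 1 instance +25% -> 1.25, which rounds up to 2
        return current + 1
    # e.g. 10 instances +12.5% -> 11.25, which truncates to 11
    return int(desired)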
class CognitoIdpBackend(BaseBackend): def __init__(self, region): super(CognitoIdpBackend, self).__init__() self.region = region self.user_pools = OrderedDict() self.user_pool_domains = OrderedDict() self.sessions = {} def reset(self): region = self.region self.__dict__ = {} self.__init__(region) # User pool def create_user_pool(self, name, extended_config): user_pool = CognitoIdpUserPool(self.region, name, extended_config) self.user_pools[user_pool.id] = user_pool return user_pool @paginate(60) def list_user_pools(self, max_results=None, next_token=None): return self.user_pools.values() def describe_user_pool(self, user_pool_id): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) return user_pool def delete_user_pool(self, user_pool_id): if user_pool_id not in self.user_pools: raise ResourceNotFoundError(user_pool_id) del self.user_pools[user_pool_id] # User pool domain def create_user_pool_domain(self, user_pool_id, domain): if user_pool_id not in self.user_pools: raise ResourceNotFoundError(user_pool_id) user_pool_domain = CognitoIdpUserPoolDomain(user_pool_id, domain) self.user_pool_domains[domain] = user_pool_domain return user_pool_domain def describe_user_pool_domain(self, domain): if domain not in self.user_pool_domains: return None return self.user_pool_domains[domain] def delete_user_pool_domain(self, domain): if domain not in self.user_pool_domains: raise ResourceNotFoundError(domain) del self.user_pool_domains[domain] # User pool client def create_user_pool_client(self, user_pool_id, extended_config): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) user_pool_client = CognitoIdpUserPoolClient(user_pool_id, extended_config) user_pool.clients[user_pool_client.id] = user_pool_client return user_pool_client @paginate(60) def list_user_pool_clients(self, user_pool_id, max_results=None, next_token=None): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) return user_pool.clients.values() def describe_user_pool_client(self, user_pool_id, client_id): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) client = user_pool.clients.get(client_id) if not client: raise ResourceNotFoundError(client_id) return client def update_user_pool_client(self, user_pool_id, client_id, extended_config): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) client = user_pool.clients.get(client_id) if not client: raise ResourceNotFoundError(client_id) client.extended_config.update(extended_config) return client def delete_user_pool_client(self, user_pool_id, client_id): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) if client_id not in user_pool.clients: raise ResourceNotFoundError(client_id) del user_pool.clients[client_id] # Identity provider def create_identity_provider(self, user_pool_id, name, extended_config): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) identity_provider = CognitoIdpIdentityProvider(name, extended_config) user_pool.identity_providers[name] = identity_provider return identity_provider @paginate(60) def list_identity_providers(self, user_pool_id, max_results=None, next_token=None): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) return 
user_pool.identity_providers.values() def describe_identity_provider(self, user_pool_id, name): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) identity_provider = user_pool.identity_providers.get(name) if not identity_provider: raise ResourceNotFoundError(name) return identity_provider def update_identity_provider(self, user_pool_id, name, extended_config): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) identity_provider = user_pool.identity_providers.get(name) if not identity_provider: raise ResourceNotFoundError(name) identity_provider.extended_config.update(extended_config) return identity_provider def delete_identity_provider(self, user_pool_id, name): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) if name not in user_pool.identity_providers: raise ResourceNotFoundError(name) del user_pool.identity_providers[name] # Group def create_group(self, user_pool_id, group_name, description, role_arn, precedence): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) group = CognitoIdpGroup(user_pool_id, group_name, description, role_arn, precedence) if group.group_name in user_pool.groups: raise GroupExistsException("A group with the name already exists") user_pool.groups[group.group_name] = group return group def get_group(self, user_pool_id, group_name): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) if group_name not in user_pool.groups: raise ResourceNotFoundError(group_name) return user_pool.groups[group_name] def list_groups(self, user_pool_id): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) return user_pool.groups.values() def delete_group(self, user_pool_id, group_name): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) if group_name not in user_pool.groups: raise ResourceNotFoundError(group_name) group = user_pool.groups[group_name] for user in group.users: user.groups.remove(group) del user_pool.groups[group_name] def admin_add_user_to_group(self, user_pool_id, group_name, username): group = self.get_group(user_pool_id, group_name) user = self.admin_get_user(user_pool_id, username) group.users.add(user) user.groups.add(group) def list_users_in_group(self, user_pool_id, group_name): group = self.get_group(user_pool_id, group_name) return list(group.users) def admin_list_groups_for_user(self, user_pool_id, username): user = self.admin_get_user(user_pool_id, username) return list(user.groups) def admin_remove_user_from_group(self, user_pool_id, group_name, username): group = self.get_group(user_pool_id, group_name) user = self.admin_get_user(user_pool_id, username) group.users.discard(user) user.groups.discard(group) # User def admin_create_user(self, user_pool_id, username, temporary_password, attributes): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) user = CognitoIdpUser(user_pool_id, username, temporary_password, UserStatus["FORCE_CHANGE_PASSWORD"], attributes) user_pool.users[user.username] = user return user def admin_get_user(self, user_pool_id, username): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) if username not in user_pool.users: raise 
UserNotFoundError(username) return user_pool.users[username] @paginate(60, "pagination_token", "limit") def list_users(self, user_pool_id, pagination_token=None, limit=None): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) return user_pool.users.values() def admin_disable_user(self, user_pool_id, username): user = self.admin_get_user(user_pool_id, username) user.enabled = False def admin_enable_user(self, user_pool_id, username): user = self.admin_get_user(user_pool_id, username) user.enabled = True def admin_delete_user(self, user_pool_id, username): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) if username not in user_pool.users: raise UserNotFoundError(username) user = user_pool.users[username] for group in user.groups: group.users.remove(user) del user_pool.users[username] def _log_user_in(self, user_pool, client, username): refresh_token = user_pool.create_refresh_token(client.id, username) access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token(refresh_token) return { "AuthenticationResult": { "IdToken": id_token, "AccessToken": access_token, "RefreshToken": refresh_token, "ExpiresIn": expires_in, } } def admin_initiate_auth(self, user_pool_id, client_id, auth_flow, auth_parameters): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) client = user_pool.clients.get(client_id) if not client: raise ResourceNotFoundError(client_id) if auth_flow == "ADMIN_NO_SRP_AUTH": username = auth_parameters.get("USERNAME") password = auth_parameters.get("PASSWORD") user = user_pool.users.get(username) if not user: raise UserNotFoundError(username) if user.password != password: raise NotAuthorizedError(username) if user.status == UserStatus["FORCE_CHANGE_PASSWORD"]: session = str(uuid.uuid4()) self.sessions[session] = user_pool return { "ChallengeName": "NEW_PASSWORD_REQUIRED", "ChallengeParameters": {}, "Session": session, } return self._log_user_in(user_pool, client, username) elif auth_flow == "REFRESH_TOKEN": refresh_token = auth_parameters.get("REFRESH_TOKEN") access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token(refresh_token) return { "AuthenticationResult": { "IdToken": id_token, "AccessToken": access_token, "ExpiresIn": expires_in, } } else: return {} def respond_to_auth_challenge(self, session, client_id, challenge_name, challenge_responses): user_pool = self.sessions.get(session) if not user_pool: raise ResourceNotFoundError(session) client = user_pool.clients.get(client_id) if not client: raise ResourceNotFoundError(client_id) if challenge_name == "NEW_PASSWORD_REQUIRED": username = challenge_responses.get("USERNAME") new_password = challenge_responses.get("NEW_PASSWORD") user = user_pool.users.get(username) if not user: raise UserNotFoundError(username) user.password = new_password user.status = UserStatus["CONFIRMED"] del self.sessions[session] return self._log_user_in(user_pool, client, username) else: return {} def confirm_forgot_password(self, client_id, username, password): for user_pool in self.user_pools.values(): if client_id in user_pool.clients and username in user_pool.users: user_pool.users[username].password = password break else: raise ResourceNotFoundError(client_id) def change_password(self, access_token, previous_password, proposed_password): for user_pool in self.user_pools.values(): if access_token in user_pool.access_tokens: _, username = 
user_pool.access_tokens[access_token] user = user_pool.users.get(username) if not user: raise UserNotFoundError(username) if user.password != previous_password: raise NotAuthorizedError(username) user.password = proposed_password if user.status == UserStatus["FORCE_CHANGE_PASSWORD"]: user.status = UserStatus["CONFIRMED"] break else: raise NotAuthorizedError(access_token) def admin_update_user_attributes(self, user_pool_id, username, attributes): user_pool = self.user_pools.get(user_pool_id) if not user_pool: raise ResourceNotFoundError(user_pool_id) if username not in user_pool.users: raise UserNotFoundError(username) user = user_pool.users[username] user.update_attributes(attributes)
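# Flow sketch for the FORCE_CHANGE_PASSWORD path above. The backend, pool and
# client ids, and credentials are hypothetical; only the method calls mirror
# the implementations in this class:
def first_login(backend, user_pool_id, client_id, username, temp_pw, new_pw):
    backend.admin_create_user(user_pool_id, username, temp_pw, attributes=[])
    result = backend.admin_initiate_auth(
        user_pool_id, client_id, "ADMIN_NO_SRP_AUTH",
        {"USERNAME": username, "PASSWORD": temp_pw})
    # A freshly created user is challenged before any tokens are issued.
    assert result["ChallengeName"] == "NEW_PASSWORD_REQUIRED"
    return backend.respond_to_auth_challenge(
        result["Session"], client_id, "NEW_PASSWORD_REQUIRED",
        {"USERNAME": username, "NEW_PASSWORD": new_pw})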
class CognitoIdpUserPool(BaseModel): def __init__(self, region, name, extended_config): self.region = region self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex)) self.name = name self.status = None self.extended_config = extended_config or {} self.creation_date = datetime.datetime.utcnow() self.last_modified_date = datetime.datetime.utcnow() self.clients = OrderedDict() self.identity_providers = OrderedDict() self.groups = OrderedDict() self.users = OrderedDict() self.refresh_tokens = {} self.access_tokens = {} self.id_tokens = {} with open(os.path.join(os.path.dirname(__file__), "resources/jwks-private.json")) as f: self.json_web_key = json.loads(f.read()) def _base_json(self): return { "Id": self.id, "Name": self.name, "Status": self.status, "CreationDate": time.mktime(self.creation_date.timetuple()), "LastModifiedDate": time.mktime(self.last_modified_date.timetuple()), } def to_json(self, extended=False): user_pool_json = self._base_json() if extended: user_pool_json.update(self.extended_config) else: user_pool_json["LambdaConfig"] = self.extended_config.get("LambdaConfig") or {} return user_pool_json def create_jwt(self, client_id, username, expires_in=60 * 60, extra_data=None): now = int(time.time()) payload = { "iss": "https://cognito-idp.{}.amazonaws.com/{}".format(self.region, self.id), "sub": self.users[username].id, "aud": client_id, "token_use": "id", "auth_time": now, "exp": now + expires_in, } payload.update(extra_data or {}) return jws.sign(payload, self.json_web_key, algorithm='RS256'), expires_in def create_id_token(self, client_id, username): id_token, expires_in = self.create_jwt(client_id, username) self.id_tokens[id_token] = (client_id, username) return id_token, expires_in def create_refresh_token(self, client_id, username): refresh_token = str(uuid.uuid4()) self.refresh_tokens[refresh_token] = (client_id, username) return refresh_token def create_access_token(self, client_id, username): extra_data = self.get_user_extra_data_by_client_id( client_id, username ) access_token, expires_in = self.create_jwt(client_id, username, extra_data=extra_data) self.access_tokens[access_token] = (client_id, username) return access_token, expires_in def create_tokens_from_refresh_token(self, refresh_token): token_pair = self.refresh_tokens.get(refresh_token) if token_pair is None: raise NotAuthorizedError(refresh_token) client_id, username = token_pair access_token, expires_in = self.create_access_token(client_id, username) id_token, _ = self.create_id_token(client_id, username) return access_token, id_token, expires_in def get_user_extra_data_by_client_id(self, client_id, username): extra_data = {} current_client = self.clients.get(client_id, None) if current_client: for readable_field in current_client.get_readable_fields(): attribute = list(filter( lambda f: f['Name'] == readable_field, self.users.get(username).attributes )) if len(attribute) > 0: extra_data.update({ attribute[0]['Name']: attribute[0]['Value'] }) return extra_data
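# Sketch: the tokens minted by create_jwt above are RS256 JWS objects signed
# with the bundled private JWK, so they can be decoded back with the same key
# via python-jose (pool, client_id and username here are hypothetical):
# token, expires_in = pool.create_jwt(client_id, username)
# claims = json.loads(jws.verify(token, pool.json_web_key, algorithms=["RS256"]))
# claims["iss"]  -> "https://cognito-idp.<region>.amazonaws.com/<pool id>"
# claims["exp"] - claims["auth_time"]  -> expires_in (3600 by default)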
class CloudFormationBackend(BaseBackend): def __init__(self): self.stacks = OrderedDict() self.deleted_stacks = {} self.exports = OrderedDict() self.change_sets = OrderedDict() def create_stack(self, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, create_change_set=False): stack_id = generate_stack_id(name) new_stack = FakeStack( stack_id=stack_id, name=name, template=template, parameters=parameters, region_name=region_name, notification_arns=notification_arns, tags=tags, role_arn=role_arn, cross_stack_resources=self.exports, create_change_set=create_change_set, ) self.stacks[stack_id] = new_stack self._validate_export_uniqueness(new_stack) for export in new_stack.exports: self.exports[export.name] = export return new_stack def create_change_set(self, stack_name, change_set_name, template, parameters, region_name, change_set_type, notification_arns=None, tags=None, role_arn=None): if change_set_type == 'UPDATE': stacks = self.stacks.values() stack = None for s in stacks: if s.name == stack_name: stack = s if stack is None: raise ValidationError(stack_name) else: stack = self.create_stack(stack_name, template, parameters, region_name, notification_arns, tags, role_arn, create_change_set=True) change_set_id = generate_changeset_id(change_set_name, region_name) self.stacks[change_set_name] = { 'Id': change_set_id, 'StackId': stack.stack_id } self.change_sets[change_set_id] = stack return change_set_id, stack.stack_id def execute_change_set(self, change_set_name, stack_name=None): stack = None if change_set_name in self.change_sets: # This means arn was passed in stack = self.change_sets[change_set_name] else: for cs in self.change_sets: if self.change_sets[cs].name == change_set_name: stack = self.change_sets[cs] if stack is None: raise ValidationError(stack_name) if stack.events[-1].resource_status == 'REVIEW_IN_PROGRESS': stack._add_stack_event('CREATE_COMPLETE') else: stack._add_stack_event('UPDATE_IN_PROGRESS') stack._add_stack_event('UPDATE_COMPLETE') return True def describe_stacks(self, name_or_stack_id): stacks = self.stacks.values() if name_or_stack_id: for stack in stacks: if stack.name == name_or_stack_id or stack.stack_id == name_or_stack_id: return [stack] if self.deleted_stacks: deleted_stacks = self.deleted_stacks.values() for stack in deleted_stacks: if stack.stack_id == name_or_stack_id: return [stack] raise ValidationError(name_or_stack_id) else: return list(stacks) def list_stacks(self): return self.stacks.values() def get_stack(self, name_or_stack_id): all_stacks = dict(self.deleted_stacks, **self.stacks) if name_or_stack_id in all_stacks: # Lookup by stack id - deleted stacks included return all_stacks[name_or_stack_id] else: # Lookup by stack name - undeleted stacks only for stack in self.stacks.values(): if stack.name == name_or_stack_id: return stack def update_stack(self, name, template, role_arn=None, parameters=None, tags=None): stack = self.get_stack(name) stack.update(template, role_arn, parameters=parameters, tags=tags) return stack def list_stack_resources(self, stack_name_or_id): stack = self.get_stack(stack_name_or_id) return stack.stack_resources def delete_stack(self, name_or_stack_id): if name_or_stack_id in self.stacks: # Delete by stack id stack = self.stacks.pop(name_or_stack_id, None) stack.delete() self.deleted_stacks[stack.stack_id] = stack [self.exports.pop(export.name) for export in stack.exports] return self.stacks.pop(name_or_stack_id, None) else: # Delete by stack name for stack in 
list(self.stacks.values()): if stack.name == name_or_stack_id: self.delete_stack(stack.stack_id) def list_exports(self, token): all_exports = list(self.exports.values()) if token is None: exports = all_exports[0:100] next_token = '100' if len(all_exports) > 100 else None else: token = int(token) exports = all_exports[token:token + 100] next_token = str(token + 100) if len(all_exports) > token + 100 else None return exports, next_token def _validate_export_uniqueness(self, stack): new_stack_export_names = [x.name for x in stack.exports] export_names = self.exports.keys() if not set(export_names).isdisjoint(new_stack_export_names): raise ValidationError( stack.stack_id, message='Export names must be unique across a given region')
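# Paging sketch for list_exports above: the token is a stringified offset that
# advances in steps of 100 (cfn_backend is a hypothetical instance of the
# class above):
def iter_exports(cfn_backend):
    token = None
    while True:
        exports, token = cfn_backend.list_exports(token)
        for export in exports:
            yield export
        if token is None:
            break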
class CloudFormationBackend(BaseBackend): def __init__(self): self.stacks = OrderedDict() self.stacksets = OrderedDict() self.deleted_stacks = {} self.exports = OrderedDict() self.change_sets = OrderedDict() def create_stack_set( self, name, template, parameters, tags=None, description=None, region="us-east-1", admin_role=None, execution_role=None, ): stackset_id = generate_stackset_id(name) new_stackset = FakeStackSet( stackset_id=stackset_id, name=name, template=template, parameters=parameters, description=description, tags=tags, admin_role=admin_role, execution_role=execution_role, ) self.stacksets[stackset_id] = new_stackset return new_stackset def get_stack_set(self, name): stacksets = self.stacksets.keys() for stackset in stacksets: if self.stacksets[stackset].name == name: return self.stacksets[stackset] raise ValidationError(name) def delete_stack_set(self, name): stacksets = self.stacksets.keys() for stackset in stacksets: if self.stacksets[stackset].name == name: self.stacksets[stackset].delete() def create_stack_instances( self, stackset_name, accounts, regions, parameters, operation_id=None ): stackset = self.get_stack_set(stackset_name) stackset.create_stack_instances( accounts=accounts, regions=regions, parameters=parameters, operation_id=operation_id, ) return stackset def update_stack_set( self, stackset_name, template=None, description=None, parameters=None, tags=None, admin_role=None, execution_role=None, accounts=None, regions=None, operation_id=None, ): stackset = self.get_stack_set(stackset_name) update = stackset.update( template=template, description=description, parameters=parameters, tags=tags, admin_role=admin_role, execution_role=execution_role, accounts=accounts, regions=regions, operation_id=operation_id, ) return update def delete_stack_instances( self, stackset_name, accounts, regions, operation_id=None ): stackset = self.get_stack_set(stackset_name) stackset.delete_stack_instances(accounts, regions, operation_id) return stackset def create_stack( self, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, create_change_set=False, ): stack_id = generate_stack_id(name) new_stack = FakeStack( stack_id=stack_id, name=name, template=template, parameters=parameters, region_name=region_name, notification_arns=notification_arns, tags=tags, role_arn=role_arn, cross_stack_resources=self.exports, create_change_set=create_change_set, ) self.stacks[stack_id] = new_stack self._validate_export_uniqueness(new_stack) for export in new_stack.exports: self.exports[export.name] = export return new_stack def create_change_set( self, stack_name, change_set_name, template, parameters, region_name, change_set_type, notification_arns=None, tags=None, role_arn=None, ): stack_id = None stack_template = None if change_set_type == "UPDATE": stacks = self.stacks.values() stack = None for s in stacks: if s.name == stack_name: stack = s stack_id = stack.stack_id stack_template = stack.template if stack is None: raise ValidationError(stack_name) else: stack_id = generate_stack_id(stack_name) stack_template = template change_set_id = generate_changeset_id(change_set_name, region_name) new_change_set = FakeChangeSet( stack_id=stack_id, stack_name=stack_name, stack_template=stack_template, change_set_id=change_set_id, change_set_name=change_set_name, template=template, parameters=parameters, region_name=region_name, notification_arns=notification_arns, tags=tags, role_arn=role_arn, cross_stack_resources=self.exports, ) self.change_sets[change_set_id] = 
new_change_set self.stacks[stack_id] = new_change_set return change_set_id, stack_id def delete_change_set(self, change_set_name, stack_name=None): if change_set_name in self.change_sets: # This means arn was passed in del self.change_sets[change_set_name] else: for cs in list(self.change_sets): if self.change_sets[cs].change_set_name == change_set_name: del self.change_sets[cs] def describe_change_set(self, change_set_name, stack_name=None): change_set = None if change_set_name in self.change_sets: # This means arn was passed in change_set = self.change_sets[change_set_name] else: for cs in self.change_sets: if self.change_sets[cs].change_set_name == change_set_name: change_set = self.change_sets[cs] if change_set is None: raise ValidationError(change_set_name) return change_set def execute_change_set(self, change_set_name, stack_name=None): stack = None if change_set_name in self.change_sets: # This means arn was passed in stack = self.change_sets[change_set_name] else: for cs in self.change_sets: if self.change_sets[cs].change_set_name == change_set_name: stack = self.change_sets[cs] if stack is None: raise ValidationError(stack_name) if stack.events[-1].resource_status == "REVIEW_IN_PROGRESS": stack._add_stack_event("CREATE_COMPLETE") else: stack._add_stack_event("UPDATE_IN_PROGRESS") stack._add_stack_event("UPDATE_COMPLETE") stack.create_resources() return True def describe_stacks(self, name_or_stack_id): stacks = self.stacks.values() if name_or_stack_id: for stack in stacks: if stack.name == name_or_stack_id or stack.stack_id == name_or_stack_id: return [stack] if self.deleted_stacks: deleted_stacks = self.deleted_stacks.values() for stack in deleted_stacks: if stack.stack_id == name_or_stack_id: return [stack] raise ValidationError(name_or_stack_id) else: return list(stacks) def list_change_sets(self): return self.change_sets.values() def list_stacks(self): return [v for v in self.stacks.values()] + [ v for v in self.deleted_stacks.values() ] def get_stack(self, name_or_stack_id): all_stacks = dict(self.deleted_stacks, **self.stacks) if name_or_stack_id in all_stacks: # Lookup by stack id - deleted stacks included return all_stacks[name_or_stack_id] else: # Lookup by stack name - undeleted stacks only for stack in self.stacks.values(): if stack.name == name_or_stack_id: return stack def update_stack(self, name, template, role_arn=None, parameters=None, tags=None): stack = self.get_stack(name) stack.update(template, role_arn, parameters=parameters, tags=tags) return stack def list_stack_resources(self, stack_name_or_id): stack = self.get_stack(stack_name_or_id) if stack is None: return None return stack.stack_resources def delete_stack(self, name_or_stack_id): if name_or_stack_id in self.stacks: # Delete by stack id stack = self.stacks.pop(name_or_stack_id, None) stack.delete() self.deleted_stacks[stack.stack_id] = stack [self.exports.pop(export.name) for export in stack.exports] return self.stacks.pop(name_or_stack_id, None) else: # Delete by stack name for stack in list(self.stacks.values()): if stack.name == name_or_stack_id: self.delete_stack(stack.stack_id) def list_exports(self, token): all_exports = list(self.exports.values()) if token is None: exports = all_exports[0:100] next_token = "100" if len(all_exports) > 100 else None else: token = int(token) exports = all_exports[token : token + 100] next_token = str(token + 100) if len(all_exports) > token + 100 else None return exports, next_token def validate_template(self, template): return validate_template_cfn_lint(template) def 
_validate_export_uniqueness(self, stack): new_stack_export_names = [x.name for x in stack.exports] export_names = self.exports.keys() if not set(export_names).isdisjoint(new_stack_export_names): raise ValidationError( stack.stack_id, message="Export names must be unique across a given region", )
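# Lifecycle sketch for the change-set methods above; the stack name, change-set
# name and template value are hypothetical, and the calls mirror the signatures
# defined in this class:
# change_set_id, stack_id = backend.create_change_set(
#     "my-stack", "my-changes", template_body, [], "us-east-1", "CREATE")
# backend.describe_change_set("my-changes")  # lookup works by name or by arn
# backend.execute_change_set("my-changes")   # REVIEW_IN_PROGRESS -> CREATE_COMPLETE
# backend.delete_change_set("my-changes")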
class SNSBackend(BaseBackend): def __init__(self, region_name): super(SNSBackend, self).__init__() self.topics = OrderedDict() self.subscriptions = OrderedDict() self.applications = {} self.platform_endpoints = {} self.region_name = region_name self.sms_attributes = {} self.sms_messages = OrderedDict() self.opt_out_numbers = [ "+447420500600", "+447420505401", "+447632960543", "+447632960028", "+447700900149", "+447700900550", "+447700900545", "+447700900907", ] def reset(self): region_name = self.region_name self.__dict__ = {} self.__init__(region_name) def update_sms_attributes(self, attrs): self.sms_attributes.update(attrs) def create_topic(self, name, attributes=None, tags=None): if attributes is None: attributes = {} if ( attributes.get("FifoTopic") and attributes.get("FifoTopic").lower() == "true" ): fails_constraints = not re.match(r"^[a-zA-Z0-9_-]{1,256}\.fifo$", name) msg = "Fifo Topic names must end with .fifo and must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long." else: fails_constraints = not re.match(r"^[a-zA-Z0-9_-]{1,256}$", name) msg = "Topic names must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long." if fails_constraints: raise InvalidParameterValue(msg) candidate_topic = Topic(name, self) if attributes: for attribute in attributes: setattr( candidate_topic, camelcase_to_underscores(attribute), attributes[attribute], ) if tags: candidate_topic._tags = tags if candidate_topic.arn in self.topics: return self.topics[candidate_topic.arn] else: self.topics[candidate_topic.arn] = candidate_topic return candidate_topic def _get_values_nexttoken(self, values_map, next_token=None): if next_token is None or not next_token: next_token = 0 next_token = int(next_token) values = list(values_map.values())[next_token : next_token + DEFAULT_PAGE_SIZE] if len(values) == DEFAULT_PAGE_SIZE: next_token = next_token + DEFAULT_PAGE_SIZE else: next_token = None return values, next_token def _get_topic_subscriptions(self, topic): return [sub for sub in self.subscriptions.values() if sub.topic == topic] def list_topics(self, next_token=None): return self._get_values_nexttoken(self.topics, next_token) def delete_topic_subscriptions(self, topic): for key, value in list(self.subscriptions.items()): if value.topic == topic: self.subscriptions.pop(key) def delete_topic(self, arn): try: topic = self.get_topic(arn) self.delete_topic_subscriptions(topic) self.topics.pop(arn) except KeyError: raise SNSNotFoundError("Topic with arn {0} not found".format(arn)) def get_topic(self, arn): try: return self.topics[arn] except KeyError: raise SNSNotFoundError("Topic with arn {0} not found".format(arn)) def set_topic_attribute(self, topic_arn, attribute_name, attribute_value): topic = self.get_topic(topic_arn) setattr(topic, attribute_name, attribute_value) def subscribe(self, topic_arn, endpoint, protocol): if protocol == "sms": if re.search(r"[./-]{2,}", endpoint) or re.search( r"(^[./-]|[./-]$)", endpoint ): raise SNSInvalidParameter("Invalid SMS endpoint: {}".format(endpoint)) reduced_endpoint = re.sub(r"[./-]", "", endpoint) if not is_e164(reduced_endpoint): raise SNSInvalidParameter("Invalid SMS endpoint: {}".format(endpoint)) # AWS doesn't create duplicates old_subscription = self._find_subscription(topic_arn, endpoint, protocol) if old_subscription: return old_subscription topic = self.get_topic(topic_arn) subscription = 
Subscription(topic, endpoint, protocol) attributes = { "PendingConfirmation": "false", "ConfirmationWasAuthenticated": "true", "Endpoint": endpoint, "TopicArn": topic_arn, "Protocol": protocol, "SubscriptionArn": subscription.arn, "Owner": DEFAULT_ACCOUNT_ID, "RawMessageDelivery": "false", } if protocol in ["http", "https"]: attributes["EffectiveDeliveryPolicy"] = topic.effective_delivery_policy subscription.attributes = attributes self.subscriptions[subscription.arn] = subscription return subscription def _find_subscription(self, topic_arn, endpoint, protocol): for subscription in self.subscriptions.values(): if ( subscription.topic.arn == topic_arn and subscription.endpoint == endpoint and subscription.protocol == protocol ): return subscription return None def unsubscribe(self, subscription_arn): self.subscriptions.pop(subscription_arn, None) def list_subscriptions(self, topic_arn=None, next_token=None): if topic_arn: topic = self.get_topic(topic_arn) filtered = OrderedDict( [(sub.arn, sub) for sub in self._get_topic_subscriptions(topic)] ) return self._get_values_nexttoken(filtered, next_token) else: return self._get_values_nexttoken(self.subscriptions, next_token) def publish( self, message, arn=None, phone_number=None, subject=None, message_attributes=None, ): if subject is not None and len(subject) > 100: # Note that the AWS docs around length are wrong: https://github.com/spulec/moto/issues/1503 raise ValueError("Subject must be less than 100 characters") if phone_number: # This is only an approximation. In fact, we should try to use GSM-7 or UCS-2 encoding to count used bytes if len(message) > MAXIMUM_SMS_MESSAGE_BYTES: raise ValueError("SMS message must be less than 1600 bytes") message_id = six.text_type(uuid.uuid4()) self.sms_messages[message_id] = (phone_number, message) return message_id if len(message) > MAXIMUM_MESSAGE_LENGTH: raise InvalidParameterValue( "An error occurred (InvalidParameter) when calling the Publish operation: Invalid parameter: Message too long" ) try: topic = self.get_topic(arn) message_id = topic.publish( message, subject=subject, message_attributes=message_attributes ) except SNSNotFoundError: endpoint = self.get_endpoint(arn) message_id = endpoint.publish(message) return message_id def create_platform_application(self, region, name, platform, attributes): application = PlatformApplication(region, name, platform, attributes) self.applications[application.arn] = application return application def get_application(self, arn): try: return self.applications[arn] except KeyError: raise SNSNotFoundError("Application with arn {0} not found".format(arn)) def set_application_attributes(self, arn, attributes): application = self.get_application(arn) application.attributes.update(attributes) return application def list_platform_applications(self): return self.applications.values() def delete_platform_application(self, platform_arn): self.applications.pop(platform_arn) def create_platform_endpoint( self, region, application, custom_user_data, token, attributes ): if any( token == endpoint.token for endpoint in self.platform_endpoints.values() ): raise DuplicateSnsEndpointError("Duplicate endpoint token: %s" % token) platform_endpoint = PlatformEndpoint( region, application, custom_user_data, token, attributes ) self.platform_endpoints[platform_endpoint.arn] = platform_endpoint return platform_endpoint def list_endpoints_by_platform_application(self, application_arn): return [ endpoint for endpoint in self.platform_endpoints.values() if endpoint.application.arn == 
application_arn ] def get_endpoint(self, arn): try: return self.platform_endpoints[arn] except KeyError: raise SNSNotFoundError("Endpoint does not exist") def set_endpoint_attributes(self, arn, attributes): endpoint = self.get_endpoint(arn) if "Enabled" in attributes: attributes["Enabled"] = attributes["Enabled"].lower() endpoint.attributes.update(attributes) return endpoint def delete_endpoint(self, arn): try: del self.platform_endpoints[arn] except KeyError: raise SNSNotFoundError("Endpoint with arn {0} not found".format(arn)) def get_subscription_attributes(self, arn): _subscription = [_ for _ in self.subscriptions.values() if _.arn == arn] if not _subscription: raise SNSNotFoundError("Subscription with arn {0} not found".format(arn)) subscription = _subscription[0] return subscription.attributes def set_subscription_attributes(self, arn, name, value): if name not in [ "RawMessageDelivery", "DeliveryPolicy", "FilterPolicy", "RedrivePolicy", ]: raise SNSInvalidParameter("AttributeName") # TODO: should do validation _subscription = [_ for _ in self.subscriptions.values() if _.arn == arn] if not _subscription: raise SNSNotFoundError("Subscription with arn {0} not found".format(arn)) subscription = _subscription[0] subscription.attributes[name] = value if name == "FilterPolicy": filter_policy = json.loads(value) self._validate_filter_policy(filter_policy) subscription._filter_policy = filter_policy def _validate_filter_policy(self, value): # TODO: extend validation checks combinations = 1 for rules in six.itervalues(value): combinations *= len(rules) # Even the official documentation states the total combination of values must not exceed 100, in reality it is 150 # https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints if combinations > 150: raise SNSInvalidParameter( "Invalid parameter: FilterPolicy: Filter policy is too complex" ) for field, rules in six.iteritems(value): for rule in rules: if rule is None: continue if isinstance(rule, six.string_types): continue if isinstance(rule, bool): continue if isinstance(rule, (six.integer_types, float)): if rule <= -1000000000 or rule >= 1000000000: raise InternalError("Unknown") continue if isinstance(rule, dict): keyword = list(rule.keys())[0] attributes = list(rule.values())[0] if keyword == "anything-but": continue elif keyword == "exists": if not isinstance(attributes, bool): raise SNSInvalidParameter( "Invalid parameter: FilterPolicy: exists match pattern must be either true or false." 
) continue elif keyword == "numeric": continue elif keyword == "prefix": continue else: raise SNSInvalidParameter( "Invalid parameter: FilterPolicy: Unrecognized match type {type}".format( type=keyword ) ) raise SNSInvalidParameter( "Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null" ) def add_permission(self, topic_arn, label, aws_account_ids, action_names): if topic_arn not in self.topics: raise SNSNotFoundError("Topic does not exist") policy = self.topics[topic_arn]._policy_json statement = next( ( statement for statement in policy["Statement"] if statement["Sid"] == label ), None, ) if statement: raise SNSInvalidParameter("Statement already exists") if any(action_name not in VALID_POLICY_ACTIONS for action_name in action_names): raise SNSInvalidParameter("Policy statement action out of service scope!") principals = [ "arn:aws:iam::{}:root".format(account_id) for account_id in aws_account_ids ] actions = ["SNS:{}".format(action_name) for action_name in action_names] statement = { "Sid": label, "Effect": "Allow", "Principal": {"AWS": principals[0] if len(principals) == 1 else principals}, "Action": actions[0] if len(actions) == 1 else actions, "Resource": topic_arn, } self.topics[topic_arn]._policy_json["Statement"].append(statement) def remove_permission(self, topic_arn, label): if topic_arn not in self.topics: raise SNSNotFoundError("Topic does not exist") statements = self.topics[topic_arn]._policy_json["Statement"] statements = [ statement for statement in statements if statement["Sid"] != label ] self.topics[topic_arn]._policy_json["Statement"] = statements def list_tags_for_resource(self, resource_arn): if resource_arn not in self.topics: raise ResourceNotFoundError return self.topics[resource_arn]._tags def tag_resource(self, resource_arn, tags): if resource_arn not in self.topics: raise ResourceNotFoundError updated_tags = self.topics[resource_arn]._tags.copy() updated_tags.update(tags) if len(updated_tags) > 50: raise TagLimitExceededError self.topics[resource_arn]._tags = updated_tags def untag_resource(self, resource_arn, tag_keys): if resource_arn not in self.topics: raise ResourceNotFoundError for key in tag_keys: self.topics[resource_arn]._tags.pop(key, None)
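# Sketch of the combination limit enforced by _validate_filter_policy above:
# the product of the per-field value counts must stay at or below 150.
# {"store": ["example_corp"], "event": ["order_placed", "order_cancelled"]}
#   -> 1 * 2 = 2 combinations, accepted
# {"a": list(range(20)), "b": list(range(10))}
#   -> 20 * 10 = 200 combinations, rejected with
#   "Invalid parameter: FilterPolicy: Filter policy is too complex"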
class SNSBackend(BaseBackend): def __init__(self, region_name): super(SNSBackend, self).__init__() self.topics = OrderedDict() self.subscriptions = OrderedDict() self.applications = {} self.platform_endpoints = {} self.region_name = region_name self.sms_attributes = {} self.opt_out_numbers = ['+447420500600', '+447420505401', '+447632960543', '+447632960028', '+447700900149', '+447700900550', '+447700900545', '+447700900907'] self.permissions = {} def reset(self): region_name = self.region_name self.__dict__ = {} self.__init__(region_name) def update_sms_attributes(self, attrs): self.sms_attributes.update(attrs) def create_topic(self, name, attributes=None): fails_constraints = not re.match(r'^[a-zA-Z0-9_-]{1,256}$', name) if fails_constraints: raise InvalidParameterValue("Topic names must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long.") candidate_topic = Topic(name, self) if attributes: for attribute in attributes: setattr(candidate_topic, camelcase_to_underscores(attribute), attributes[attribute]) if candidate_topic.arn in self.topics: return self.topics[candidate_topic.arn] else: self.topics[candidate_topic.arn] = candidate_topic return candidate_topic def _get_values_nexttoken(self, values_map, next_token=None): if next_token is None or not next_token: next_token = 0 next_token = int(next_token) values = list(values_map.values())[ next_token: next_token + DEFAULT_PAGE_SIZE] if len(values) == DEFAULT_PAGE_SIZE: next_token = next_token + DEFAULT_PAGE_SIZE else: next_token = None return values, next_token def _get_topic_subscriptions(self, topic): return [sub for sub in self.subscriptions.values() if sub.topic == topic] def list_topics(self, next_token=None): return self._get_values_nexttoken(self.topics, next_token) def delete_topic(self, arn): topic = self.get_topic(arn) subscriptions = self._get_topic_subscriptions(topic) for sub in subscriptions: self.unsubscribe(sub.arn) self.topics.pop(arn) def get_topic(self, arn): try: return self.topics[arn] except KeyError: raise SNSNotFoundError("Topic with arn {0} not found".format(arn)) def get_topic_from_phone_number(self, number): for subscription in self.subscriptions.values(): if subscription.protocol == 'sms' and subscription.endpoint == number: return subscription.topic.arn raise SNSNotFoundError('Could not find valid subscription') def set_topic_attribute(self, topic_arn, attribute_name, attribute_value): topic = self.get_topic(topic_arn) setattr(topic, attribute_name, attribute_value) def subscribe(self, topic_arn, endpoint, protocol): # AWS doesn't create duplicates old_subscription = self._find_subscription(topic_arn, endpoint, protocol) if old_subscription: return old_subscription topic = self.get_topic(topic_arn) subscription = Subscription(topic, endpoint, protocol) self.subscriptions[subscription.arn] = subscription return subscription def _find_subscription(self, topic_arn, endpoint, protocol): for subscription in self.subscriptions.values(): if subscription.topic.arn == topic_arn and subscription.endpoint == endpoint and subscription.protocol == protocol: return subscription return None def unsubscribe(self, subscription_arn): self.subscriptions.pop(subscription_arn) def list_subscriptions(self, topic_arn=None, next_token=None): if topic_arn: topic = self.get_topic(topic_arn) filtered = OrderedDict( [(sub.arn, sub) for sub in self._get_topic_subscriptions(topic)]) return self._get_values_nexttoken(filtered, next_token) else: return 
self._get_values_nexttoken(self.subscriptions, next_token) def publish(self, arn, message, subject=None, message_attributes=None): if subject is not None and len(subject) > 100: # Note that the AWS docs around length are wrong: https://github.com/spulec/moto/issues/1503 raise ValueError('Subject must be less than 100 characters') if len(message) > MAXIMUM_MESSAGE_LENGTH: raise InvalidParameterValue("An error occurred (InvalidParameter) when calling the Publish operation: Invalid parameter: Message too long") try: topic = self.get_topic(arn) message_id = topic.publish(message, subject=subject, message_attributes=message_attributes) except SNSNotFoundError: endpoint = self.get_endpoint(arn) message_id = endpoint.publish(message) return message_id def create_platform_application(self, region, name, platform, attributes): application = PlatformApplication(region, name, platform, attributes) self.applications[application.arn] = application return application def get_application(self, arn): try: return self.applications[arn] except KeyError: raise SNSNotFoundError( "Application with arn {0} not found".format(arn)) def set_application_attributes(self, arn, attributes): application = self.get_application(arn) application.attributes.update(attributes) return application def list_platform_applications(self): return self.applications.values() def delete_platform_application(self, platform_arn): self.applications.pop(platform_arn) def create_platform_endpoint(self, region, application, custom_user_data, token, attributes): if any(token == endpoint.token for endpoint in self.platform_endpoints.values()): raise DuplicateSnsEndpointError("Duplicate endpoint token: %s" % token) platform_endpoint = PlatformEndpoint( region, application, custom_user_data, token, attributes) self.platform_endpoints[platform_endpoint.arn] = platform_endpoint return platform_endpoint def list_endpoints_by_platform_application(self, application_arn): return [ endpoint for endpoint in self.platform_endpoints.values() if endpoint.application.arn == application_arn ] def get_endpoint(self, arn): try: return self.platform_endpoints[arn] except KeyError: raise SNSNotFoundError( "Endpoint with arn {0} not found".format(arn)) def set_endpoint_attributes(self, arn, attributes): endpoint = self.get_endpoint(arn) endpoint.attributes.update(attributes) return endpoint def delete_endpoint(self, arn): try: del self.platform_endpoints[arn] except KeyError: raise SNSNotFoundError( "Endpoint with arn {0} not found".format(arn)) def get_subscription_attributes(self, arn): _subscription = [_ for _ in self.subscriptions.values() if _.arn == arn] if not _subscription: raise SNSNotFoundError("Subscription with arn {0} not found".format(arn)) subscription = _subscription[0] return subscription.attributes def set_subscription_attributes(self, arn, name, value): if name not in ['RawMessageDelivery', 'DeliveryPolicy', 'FilterPolicy']: raise SNSInvalidParameter('AttributeName') # TODO: should do validation _subscription = [_ for _ in self.subscriptions.values() if _.arn == arn] if not _subscription: raise SNSNotFoundError("Subscription with arn {0} not found".format(arn)) subscription = _subscription[0] subscription.attributes[name] = value if name == 'FilterPolicy': subscription._filter_policy = json.loads(value)
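# Behaviour sketch for subscribe/_find_subscription above: re-subscribing the
# same (topic, endpoint, protocol) triple returns the existing subscription
# instead of minting a new arn (backend and endpoint values are hypothetical):
# topic = sns_backend.create_topic("alerts")
# first = sns_backend.subscribe(topic.arn, "user@example.com", "email")
# again = sns_backend.subscribe(topic.arn, "user@example.com", "email")
# assert first is again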
class KinesisBackend(BaseBackend):
    def __init__(self):
        self.streams = OrderedDict()
        self.delivery_streams = {}

    def create_stream(self, stream_name, shard_count, region):
        if stream_name in self.streams:
            raise ResourceInUseError(stream_name)
        stream = Stream(stream_name, shard_count, region)
        self.streams[stream_name] = stream
        return stream

    def describe_stream(self, stream_name):
        if stream_name in self.streams:
            return self.streams[stream_name]
        else:
            raise StreamNotFoundError(stream_name)

    def list_streams(self):
        return self.streams.values()

    def delete_stream(self, stream_name):
        if stream_name in self.streams:
            return self.streams.pop(stream_name)
        raise StreamNotFoundError(stream_name)

    def get_shard_iterator(self, stream_name, shard_id, shard_iterator_type,
                           starting_sequence_number, at_timestamp):
        # Validate params
        stream = self.describe_stream(stream_name)
        shard = stream.get_shard(shard_id)
        shard_iterator = compose_new_shard_iterator(
            stream_name, shard, shard_iterator_type,
            starting_sequence_number, at_timestamp)
        return shard_iterator

    def get_records(self, shard_iterator, limit):
        decomposed = decompose_shard_iterator(shard_iterator)
        stream_name, shard_id, last_sequence_id = decomposed
        stream = self.describe_stream(stream_name)
        shard = stream.get_shard(shard_id)
        records, last_sequence_id, millis_behind_latest = shard.get_records(
            last_sequence_id, limit)
        next_shard_iterator = compose_shard_iterator(
            stream_name, shard, last_sequence_id)
        return next_shard_iterator, records, millis_behind_latest

    def put_record(self, stream_name, partition_key, explicit_hash_key,
                   sequence_number_for_ordering, data):
        stream = self.describe_stream(stream_name)
        sequence_number, shard_id = stream.put_record(
            partition_key, explicit_hash_key, sequence_number_for_ordering, data)
        return sequence_number, shard_id

    def put_records(self, stream_name, records):
        stream = self.describe_stream(stream_name)
        response = {"FailedRecordCount": 0, "Records": []}
        for record in records:
            partition_key = record.get("PartitionKey")
            explicit_hash_key = record.get("ExplicitHashKey")
            data = record.get("Data")
            sequence_number, shard_id = stream.put_record(
                partition_key, explicit_hash_key, None, data)
            response['Records'].append({
                "SequenceNumber": sequence_number,
                "ShardId": shard_id})
        return response

    def split_shard(self, stream_name, shard_to_split, new_starting_hash_key):
        stream = self.describe_stream(stream_name)
        if shard_to_split not in stream.shards:
            raise ResourceNotFoundError(shard_to_split)
        # Anchor the pattern so values like '0abc' are rejected
        if not re.match(r'^(0|[1-9]\d{0,38})$', new_starting_hash_key):
            raise InvalidArgumentError(new_starting_hash_key)
        new_starting_hash_key = int(new_starting_hash_key)
        shard = stream.shards[shard_to_split]
        last_id = sorted(stream.shards.values(),
                         key=attrgetter('_shard_id'))[-1]._shard_id
        if shard.starting_hash < new_starting_hash_key < shard.ending_hash:
            new_shard = Shard(last_id + 1, new_starting_hash_key, shard.ending_hash)
            shard.ending_hash = new_starting_hash_key
            stream.shards[new_shard.shard_id] = new_shard
        else:
            raise InvalidArgumentError(new_starting_hash_key)
        # re-route the split shard's records so each lands on the correct side
        records = shard.records
        shard.records = OrderedDict()
        for index in records:
            record = records[index]
            stream.put_record(record.partition_key, record.explicit_hash_key,
                              None, record.data)

    def merge_shards(self, stream_name, shard_to_merge, adjacent_shard_to_merge):
        stream = self.describe_stream(stream_name)
        if shard_to_merge not in stream.shards:
            raise ResourceNotFoundError(shard_to_merge)
        if adjacent_shard_to_merge not in stream.shards:
            raise ResourceNotFoundError(adjacent_shard_to_merge)
        shard1 = stream.shards[shard_to_merge]
        shard2 = stream.shards[adjacent_shard_to_merge]
        if shard1.ending_hash == shard2.starting_hash:
            shard1.ending_hash = shard2.ending_hash
        elif shard2.ending_hash == shard1.starting_hash:
            shard1.starting_hash = shard2.starting_hash
        else:
            raise InvalidArgumentError(adjacent_shard_to_merge)
        del stream.shards[shard2.shard_id]
        for index in shard2.records:
            record = shard2.records[index]
            shard1.put_record(record.partition_key, record.data,
                              record.explicit_hash_key)

    ''' Firehose '''

    def create_delivery_stream(self, stream_name, **stream_kwargs):
        stream = DeliveryStream(stream_name, **stream_kwargs)
        self.delivery_streams[stream_name] = stream
        return stream

    def get_delivery_stream(self, stream_name):
        if stream_name in self.delivery_streams:
            return self.delivery_streams[stream_name]
        else:
            raise StreamNotFoundError(stream_name)

    def list_delivery_streams(self):
        return self.delivery_streams.values()

    def delete_delivery_stream(self, stream_name):
        self.delivery_streams.pop(stream_name)

    def put_firehose_record(self, stream_name, record_data):
        stream = self.get_delivery_stream(stream_name)
        record = stream.put_record(record_data)
        return record

    def list_tags_for_stream(self, stream_name, exclusive_start_tag_key=None,
                             limit=None):
        stream = self.describe_stream(stream_name)
        tags = []
        result = {'HasMoreTags': False, 'Tags': tags}
        for key, val in sorted(stream.tags.items(), key=lambda x: x[0]):
            if limit and len(tags) >= limit:
                result['HasMoreTags'] = True
                break
            # the start key is exclusive, so skip keys up to and including it
            if exclusive_start_tag_key and key <= exclusive_start_tag_key:
                continue
            tags.append({'Key': key, 'Value': val})
        return result

    def add_tags_to_stream(self, stream_name, tags):
        stream = self.describe_stream(stream_name)
        stream.tags.update(tags)

    def remove_tags_from_stream(self, stream_name, tag_keys):
        stream = self.describe_stream(stream_name)
        for key in tag_keys:
            if key in stream.tags:
                del stream.tags[key]
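# Consume-loop sketch for the iterator methods above, assuming the module's
# compose/decompose iterator helpers and that a one-shard stream names its
# shard 'shardId-000000000000' (per Shard.shard_id elsewhere in this module).
def _kinesis_consume_demo():
    backend = KinesisBackend()
    backend.create_stream('my-stream', shard_count=1, region='us-east-1')
    backend.put_record('my-stream', 'partition-1', None, None, b'payload')
    iterator = backend.get_shard_iterator(
        'my-stream', 'shardId-000000000000', 'TRIM_HORIZON', None, None)
    next_iterator, records, millis_behind = backend.get_records(iterator, limit=10)
    return next_iterator, records, millis_behind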
class DynamoDBBackend(BaseBackend):
    def __init__(self):
        self.tables = OrderedDict()

    def create_table(self, name, **params):
        table = Table(name, **params)
        self.tables[name] = table
        return table

    def delete_table(self, name):
        return self.tables.pop(name, None)

    def update_table_throughput(self, name, throughput):
        table = self.tables[name]
        table.throughput = throughput
        return table

    def put_item(self, table_name, item_attrs, expected=None, overwrite=False):
        table = self.tables.get(table_name)
        if not table:
            return None
        return table.put_item(item_attrs, expected, overwrite)

    def get_table_keys_name(self, table_name, keys):
        """Given a set of keys, extracts the key and range key"""
        table = self.tables.get(table_name)
        if not table:
            return None, None
        hash_key = range_key = None
        for key in keys:
            if key in table.hash_key_names:
                hash_key = key
            elif key in table.range_key_names:
                range_key = key
        return hash_key, range_key

    def get_keys_value(self, table, keys):
        if table.hash_key_attr not in keys or (
                table.has_range_key and table.range_key_attr not in keys):
            raise ValueError(
                "Table has a range key, but no range key was passed into get_item")
        hash_key = DynamoType(keys[table.hash_key_attr])
        range_key = DynamoType(
            keys[table.range_key_attr]) if table.has_range_key else None
        return hash_key, range_key

    def get_table(self, table_name):
        return self.tables.get(table_name)

    def get_item(self, table_name, keys):
        table = self.get_table(table_name)
        if not table:
            return None
        hash_key, range_key = self.get_keys_value(table, keys)
        return table.get_item(hash_key, range_key)

    def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts):
        table = self.tables.get(table_name)
        if not table:
            return None, None
        hash_key = DynamoType(hash_key_dict)
        range_values = [DynamoType(range_value)
                        for range_value in range_value_dicts]
        return table.query(hash_key, range_comparison, range_values)

    def scan(self, table_name, filters):
        table = self.tables.get(table_name)
        if not table:
            return None, None, None
        scan_filters = {}
        for key, (comparison_operator, comparison_values) in filters.items():
            dynamo_types = [DynamoType(value) for value in comparison_values]
            scan_filters[key] = (comparison_operator, dynamo_types)
        return table.scan(scan_filters)

    def update_item(self, table_name, key, update_expression):
        table = self.get_table(table_name)
        hash_value = DynamoType(key)
        item = table.get_item(hash_value)
        item.update(update_expression)
        return item

    def delete_item(self, table_name, keys):
        table = self.tables.get(table_name)
        if not table:
            return None
        hash_key, range_key = self.get_keys_value(table, keys)
        return table.delete_item(hash_key, range_key)
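# The ``keys`` argument above is DynamoDB's wire format: a dict mapping each
# key attribute to a single-entry type/value dict. A hedged sketch (table and
# attribute names illustrative):
def _dynamo_get_item_demo(backend):
    keys = {'id': {'S': 'user-1'}, 'ts': {'N': '1'}}  # hash key + range key
    return backend.get_item('my-table', keys)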
class DynamoDBBackend(BaseBackend):
    def __init__(self, region_name=None):
        self.region_name = region_name
        self.tables = OrderedDict()

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def create_table(self, name, **params):
        if name in self.tables:
            return None
        table = Table(name, **params)
        self.tables[name] = table
        return table

    def delete_table(self, name):
        return self.tables.pop(name, None)

    def tag_resource(self, table_arn, tags):
        for table in self.tables:
            if self.tables[table].table_arn == table_arn:
                self.tables[table].tags.extend(tags)

    def untag_resource(self, table_arn, tag_keys):
        for table in self.tables:
            if self.tables[table].table_arn == table_arn:
                self.tables[table].tags = [
                    tag for tag in self.tables[table].tags
                    if tag['Key'] not in tag_keys]

    def list_tags_of_resource(self, table_arn):
        required_table = None
        for table in self.tables:
            if self.tables[table].table_arn == table_arn:
                required_table = self.tables[table]
        return required_table.tags

    def update_table_throughput(self, name, throughput):
        table = self.tables[name]
        table.throughput = throughput
        return table

    def update_table_streams(self, name, stream_specification):
        table = self.tables[name]
        if ((stream_specification.get('StreamEnabled') or
                stream_specification.get('StreamViewType')) and
                table.latest_stream_label):
            raise ValueError('Table already has stream enabled')
        table.set_stream_specification(stream_specification)
        return table

    def update_table_global_indexes(self, name, global_index_updates):
        table = self.tables[name]
        gsis_by_name = dict((i['IndexName'], i) for i in table.global_indexes)
        for gsi_update in global_index_updates:
            gsi_to_create = gsi_update.get('Create')
            gsi_to_update = gsi_update.get('Update')
            gsi_to_delete = gsi_update.get('Delete')
            if gsi_to_delete:
                index_name = gsi_to_delete['IndexName']
                if index_name not in gsis_by_name:
                    raise ValueError(
                        'Global Secondary Index does not exist, but tried to '
                        'delete: %s' % gsi_to_delete['IndexName'])
                del gsis_by_name[index_name]
            if gsi_to_update:
                index_name = gsi_to_update['IndexName']
                if index_name not in gsis_by_name:
                    raise ValueError(
                        'Global Secondary Index does not exist, but tried to '
                        'update: %s' % gsi_to_update['IndexName'])
                gsis_by_name[index_name].update(gsi_to_update)
            if gsi_to_create:
                if gsi_to_create['IndexName'] in gsis_by_name:
                    raise ValueError(
                        'Global Secondary Index already exists: %s'
                        % gsi_to_create['IndexName'])
                gsis_by_name[gsi_to_create['IndexName']] = gsi_to_create
        # in python 3.6, dict.values() returns a dict_values object, but we
        # expect it to be a list in other parts of the codebase
        table.global_indexes = list(gsis_by_name.values())
        return table

    def put_item(self, table_name, item_attrs, expected=None, overwrite=False):
        table = self.tables.get(table_name)
        if not table:
            return None
        return table.put_item(item_attrs, expected, overwrite)

    def get_table_keys_name(self, table_name, keys):
        """Given a set of keys, extracts the key and range key"""
        table = self.tables.get(table_name)
        if not table:
            return None, None
        if len(keys) == 1:
            for key in keys:
                if key in table.hash_key_names:
                    return key, None
        potential_hash, potential_range = None, None
        for key in set(keys):
            if key in table.hash_key_names:
                potential_hash = key
            elif key in table.range_key_names:
                potential_range = key
        return potential_hash, potential_range

    def get_keys_value(self, table, keys):
        if table.hash_key_attr not in keys or (
                table.has_range_key and table.range_key_attr not in keys):
            raise ValueError(
                "Table has a range key, but no range key was passed into get_item")
        hash_key = DynamoType(keys[table.hash_key_attr])
        range_key = DynamoType(
            keys[table.range_key_attr]) if table.has_range_key else None
        return hash_key, range_key

    def get_table(self, table_name):
        return self.tables.get(table_name)

    def get_item(self, table_name, keys):
        table = self.get_table(table_name)
        if not table:
            raise ValueError("No table found")
        hash_key, range_key = self.get_keys_value(table, keys)
        return table.get_item(hash_key, range_key)

    def query(self, table_name, hash_key_dict, range_comparison,
              range_value_dicts, limit, exclusive_start_key, scan_index_forward,
              projection_expression, index_name=None, expr_names=None,
              expr_values=None, filter_expression=None, **filter_kwargs):
        table = self.tables.get(table_name)
        if not table:
            return None, None
        hash_key = DynamoType(hash_key_dict)
        range_values = [DynamoType(range_value)
                        for range_value in range_value_dicts]
        if filter_expression is not None:
            filter_expression = get_filter_expression(
                filter_expression, expr_names, expr_values)
        else:
            filter_expression = Op(None, None)  # Will always eval to true
        return table.query(hash_key, range_comparison, range_values, limit,
                           exclusive_start_key, scan_index_forward,
                           projection_expression, index_name, filter_expression,
                           **filter_kwargs)

    def scan(self, table_name, filters, limit, exclusive_start_key,
             filter_expression, expr_names, expr_values):
        table = self.tables.get(table_name)
        if not table:
            return None, None, None
        scan_filters = {}
        for key, (comparison_operator, comparison_values) in filters.items():
            dynamo_types = [DynamoType(value) for value in comparison_values]
            scan_filters[key] = (comparison_operator, dynamo_types)
        if filter_expression is not None:
            filter_expression = get_filter_expression(
                filter_expression, expr_names, expr_values)
        else:
            filter_expression = Op(None, None)  # Will always eval to true
        return table.scan(scan_filters, limit, exclusive_start_key,
                          filter_expression)

    def update_item(self, table_name, key, update_expression, attribute_updates,
                    expression_attribute_names, expression_attribute_values,
                    expected=None):
        table = self.get_table(table_name)
        if all([table.hash_key_attr in key, table.range_key_attr in key]):
            # Covers cases where table has hash and range keys; ``key`` param
            # will be a dict
            hash_value = DynamoType(key[table.hash_key_attr])
            range_value = DynamoType(key[table.range_key_attr])
        elif table.hash_key_attr in key:
            # Covers tables that have a range key where ``key`` param is a dict
            hash_value = DynamoType(key[table.hash_key_attr])
            range_value = None
        else:
            # Covers other cases
            hash_value = DynamoType(key)
            range_value = None
        item = table.get_item(hash_value, range_value)
        if item is None:
            item_attr = {}
        elif hasattr(item, 'attrs'):
            item_attr = item.attrs
        else:
            item_attr = item
        if not expected:
            expected = {}
        for key, val in expected.items():
            if (('Exists' in val and val['Exists'] is False) or
                    ('ComparisonOperator' in val and
                     val['ComparisonOperator'] == 'NULL')):
                if key in item_attr:
                    raise ValueError("The conditional request failed")
            elif key not in item_attr:
                raise ValueError("The conditional request failed")
            elif ('Value' in val and
                    DynamoType(val['Value']).value != item_attr[key].value):
                raise ValueError("The conditional request failed")
            elif 'ComparisonOperator' in val:
                dynamo_types = [DynamoType(ele)
                                for ele in val.get("AttributeValueList", [])]
                if not item_attr[key].compare(val['ComparisonOperator'],
                                              dynamo_types):
                    raise ValueError('The conditional request failed')
        # Update does not fail on new items, so create one
        if item is None:
            data = {table.hash_key_attr: {hash_value.type: hash_value.value}}
            if range_value:
                data.update({
                    table.range_key_attr: {range_value.type: range_value.value}})
            table.put_item(data)
            item = table.get_item(hash_value, range_value)
        if update_expression:
            item.update(update_expression, expression_attribute_names,
                        expression_attribute_values)
        else:
            item.update_with_attribute_updates(attribute_updates)
        return item

    def delete_item(self, table_name, keys):
        table = self.get_table(table_name)
        if not table:
            return None
        hash_key, range_key = self.get_keys_value(table, keys)
        return table.delete_item(hash_key, range_key)

    def update_ttl(self, table_name, ttl_spec):
        table = self.tables.get(table_name)
        if table is None:
            raise JsonRESTError('ResourceNotFound', 'Table not found')
        if 'Enabled' not in ttl_spec or 'AttributeName' not in ttl_spec:
            raise JsonRESTError(
                'InvalidParameterValue',
                'TimeToLiveSpecification does not contain Enabled and AttributeName')
        if ttl_spec['Enabled']:
            table.ttl['TimeToLiveStatus'] = 'ENABLED'
        else:
            table.ttl['TimeToLiveStatus'] = 'DISABLED'
        table.ttl['AttributeName'] = ttl_spec['AttributeName']

    def describe_ttl(self, table_name):
        table = self.tables.get(table_name)
        if table is None:
            raise JsonRESTError('ResourceNotFound', 'Table not found')
        return table.ttl
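# Hedged sketch of the ``expected`` conditional handled above: the update only
# applies while the stored 'version' attribute equals 1. All names here are
# illustrative, and the expression syntax is the standard DynamoDB one the
# Item.update model is assumed to accept.
def _conditional_update_demo(backend):
    expected = {'version': {'ComparisonOperator': 'EQ',
                            'AttributeValueList': [{'N': '1'}]}}
    return backend.update_item(
        'my-table', {'id': {'S': 'user-1'}},
        update_expression='SET version = :v',
        attribute_updates=None,
        expression_attribute_names={},
        expression_attribute_values={':v': {'N': '2'}},
        expected=expected)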
class Shard(BaseModel):
    def __init__(self, shard_id, starting_hash, ending_hash):
        self._shard_id = shard_id
        self.starting_hash = starting_hash
        self.ending_hash = ending_hash
        self.records = OrderedDict()

    @property
    def shard_id(self):
        return "shardId-{0}".format(str(self._shard_id).zfill(12))

    def get_records(self, last_sequence_id, limit):
        last_sequence_id = int(last_sequence_id)
        results = []
        secs_behind_latest = 0
        for sequence_number, record in self.records.items():
            if sequence_number > last_sequence_id:
                results.append(record)
                last_sequence_id = sequence_number
                very_last_record = self.records[next(reversed(self.records))]
                secs_behind_latest = very_last_record.created_at - record.created_at
                if len(results) == limit:
                    break
        millis_behind_latest = int(secs_behind_latest * 1000)
        return results, last_sequence_id, millis_behind_latest

    def put_record(self, partition_key, data, explicit_hash_key):
        # Note: this function is not safe for concurrency
        if self.records:
            last_sequence_number = self.get_max_sequence_number()
        else:
            last_sequence_number = 0
        sequence_number = last_sequence_number + 1
        self.records[sequence_number] = Record(
            partition_key, data, sequence_number, explicit_hash_key)
        return sequence_number

    def get_min_sequence_number(self):
        if self.records:
            return list(self.records.keys())[0]
        return 0

    def get_max_sequence_number(self):
        if self.records:
            return list(self.records.keys())[-1]
        return 0

    def get_sequence_number_at(self, at_timestamp):
        if not self.records or at_timestamp < list(self.records.values())[0].created_at:
            return 0
        # find the last record in the list that was created before at_timestamp
        r = next((r for r in reversed(self.records.values())
                  if r.created_at < at_timestamp), None)
        # guard the boundary case where at_timestamp equals the first
        # record's created_at exactly, leaving no earlier record
        return r.sequence_number if r is not None else 0

    def to_json(self):
        return {
            "HashKeyRange": {
                "EndingHashKey": str(self.ending_hash),
                "StartingHashKey": str(self.starting_hash)
            },
            "SequenceNumberRange": {
                "EndingSequenceNumber": self.get_max_sequence_number(),
                "StartingSequenceNumber": self.get_min_sequence_number(),
            },
            "ShardId": self.shard_id
        }
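# Hedged sketch: Record.created_at is assumed to be an epoch-seconds float, so
# MillisBehindLatest is the age gap between the newest record in the shard and
# the last record returned, scaled to milliseconds.
def _millis_behind_demo(shard):
    records, last_seq, millis = shard.get_records(last_sequence_id=0, limit=2)
    # e.g. newest created_at=12.5 and last returned created_at=10.0 -> 2500
    return millis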
class ELBBackend(BaseBackend):
    def __init__(self, region_name=None):
        self.region_name = region_name
        self.load_balancers = OrderedDict()

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def create_load_balancer(self, name, zones, ports, scheme='internet-facing',
                             subnets=None, security_groups=None):
        vpc_id = None
        ec2_backend = ec2_backends[self.region_name]
        if subnets:
            subnet = ec2_backend.get_subnet(subnets[0])
            vpc_id = subnet.vpc_id
        if name in self.load_balancers:
            raise DuplicateLoadBalancerName(name)
        if not ports:
            raise EmptyListenersError()
        if not security_groups:
            security_groups = []
        for security_group in security_groups:
            if ec2_backend.get_security_group_from_id(security_group) is None:
                raise InvalidSecurityGroupError()
        new_load_balancer = FakeLoadBalancer(
            name=name, zones=zones, ports=ports, scheme=scheme,
            subnets=subnets, security_groups=security_groups, vpc_id=vpc_id)
        self.load_balancers[name] = new_load_balancer
        return new_load_balancer

    def create_load_balancer_listeners(self, name, ports):
        balancer = self.load_balancers.get(name, None)
        if balancer:
            for port in ports:
                protocol = port['protocol']
                instance_port = port['instance_port']
                lb_port = port['load_balancer_port']
                ssl_certificate_id = port.get('sslcertificate_id')
                for listener in balancer.listeners:
                    if lb_port == listener.load_balancer_port:
                        # same port: anything other than an identical listener
                        # is a conflict
                        if protocol != listener.protocol:
                            raise DuplicateListenerError(name, lb_port)
                        if instance_port != listener.instance_port:
                            raise DuplicateListenerError(name, lb_port)
                        if ssl_certificate_id != listener.ssl_certificate_id:
                            raise DuplicateListenerError(name, lb_port)
                        break
                else:
                    balancer.listeners.append(FakeListener(
                        lb_port, instance_port, protocol, ssl_certificate_id))
        return balancer

    def describe_load_balancers(self, names):
        balancers = self.load_balancers.values()
        if names:
            matched_balancers = [
                balancer for balancer in balancers if balancer.name in names]
            if len(names) != len(matched_balancers):
                # compare names against names, not against balancer objects
                missing_elb = list(
                    set(names) - set(b.name for b in matched_balancers))[0]
                raise LoadBalancerNotFoundError(missing_elb)
            return matched_balancers
        else:
            return balancers

    def delete_load_balancer_listeners(self, name, ports):
        balancer = self.load_balancers.get(name, None)
        if balancer:
            # keep each surviving listener exactly once, dropping any whose
            # load-balancer port is being deleted
            ports = set(int(lb_port) for lb_port in ports)
            balancer.listeners = [
                listener for listener in balancer.listeners
                if int(listener.load_balancer_port) not in ports]
        return balancer

    def delete_load_balancer(self, load_balancer_name):
        self.load_balancers.pop(load_balancer_name, None)

    def get_load_balancer(self, load_balancer_name):
        return self.load_balancers.get(load_balancer_name)

    def apply_security_groups_to_load_balancer(self, load_balancer_name,
                                               security_group_ids):
        load_balancer = self.load_balancers.get(load_balancer_name)
        ec2_backend = ec2_backends[self.region_name]
        for security_group_id in security_group_ids:
            if ec2_backend.get_security_group_from_id(security_group_id) is None:
                raise InvalidSecurityGroupError()
        load_balancer.security_groups = security_group_ids

    def configure_health_check(self, load_balancer_name, timeout,
                               healthy_threshold, unhealthy_threshold,
                               interval, target):
        check = FakeHealthCheck(timeout, healthy_threshold,
                                unhealthy_threshold, interval, target)
        load_balancer = self.get_load_balancer(load_balancer_name)
        load_balancer.health_check = check
        return check

    def set_load_balancer_listener_sslcertificate(self, name, lb_port,
                                                  ssl_certificate_id):
        balancer = self.load_balancers.get(name, None)
        if balancer:
            for idx, listener in enumerate(balancer.listeners):
                if lb_port == listener.load_balancer_port:
                    balancer.listeners[idx].ssl_certificate_id = ssl_certificate_id
        return balancer

    def register_instances(self, load_balancer_name, instance_ids):
        load_balancer = self.get_load_balancer(load_balancer_name)
        load_balancer.instance_ids.extend(instance_ids)
        return load_balancer

    def deregister_instances(self, load_balancer_name, instance_ids):
        load_balancer = self.get_load_balancer(load_balancer_name)
        new_instance_ids = [
            instance_id for instance_id in load_balancer.instance_ids
            if instance_id not in instance_ids]
        load_balancer.instance_ids = new_instance_ids
        return load_balancer

    def set_cross_zone_load_balancing_attribute(self, load_balancer_name, attribute):
        load_balancer = self.get_load_balancer(load_balancer_name)
        load_balancer.attributes.cross_zone_load_balancing = attribute
        return load_balancer

    def set_access_log_attribute(self, load_balancer_name, attribute):
        load_balancer = self.get_load_balancer(load_balancer_name)
        load_balancer.attributes.access_log = attribute
        return load_balancer

    def set_connection_draining_attribute(self, load_balancer_name, attribute):
        load_balancer = self.get_load_balancer(load_balancer_name)
        load_balancer.attributes.connection_draining = attribute
        return load_balancer

    def set_connection_settings_attribute(self, load_balancer_name, attribute):
        load_balancer = self.get_load_balancer(load_balancer_name)
        load_balancer.attributes.connecting_settings = attribute
        return load_balancer

    def create_lb_other_policy(self, load_balancer_name, other_policy):
        load_balancer = self.get_load_balancer(load_balancer_name)
        if other_policy.policy_name not in [
                p.policy_name for p in load_balancer.policies.other_policies]:
            load_balancer.policies.other_policies.append(other_policy)
        return load_balancer

    def create_app_cookie_stickiness_policy(self, load_balancer_name, policy):
        load_balancer = self.get_load_balancer(load_balancer_name)
        load_balancer.policies.app_cookie_stickiness_policies.append(policy)
        return load_balancer

    def create_lb_cookie_stickiness_policy(self, load_balancer_name, policy):
        load_balancer = self.get_load_balancer(load_balancer_name)
        load_balancer.policies.lb_cookie_stickiness_policies.append(policy)
        return load_balancer

    def set_load_balancer_policies_of_backend_server(self, load_balancer_name,
                                                     instance_port, policies):
        load_balancer = self.get_load_balancer(load_balancer_name)
        backend = [b for b in load_balancer.backends
                   if int(b.instance_port) == instance_port][0]
        backend_idx = load_balancer.backends.index(backend)
        backend.policy_names = policies
        load_balancer.backends[backend_idx] = backend
        return load_balancer

    def set_load_balancer_policies_of_listener(self, load_balancer_name,
                                               load_balancer_port, policies):
        load_balancer = self.get_load_balancer(load_balancer_name)
        listener = [l for l in load_balancer.listeners
                    if int(l.load_balancer_port) == load_balancer_port][0]
        listener_idx = load_balancer.listeners.index(listener)
        listener.policy_names = policies
        load_balancer.listeners[listener_idx] = listener
        return load_balancer
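# Usage sketch for the classic-ELB backend above, assuming the ec2_backends
# registry and the FakeLoadBalancer/FakeHealthCheck models from this module,
# with port dicts in the shape create_load_balancer_listeners expects.
def _elb_demo():
    backend = ELBBackend('us-east-1')
    ports = [{'protocol': 'http', 'load_balancer_port': 80, 'instance_port': 8080}]
    backend.create_load_balancer('my-lb', zones=['us-east-1a'], ports=ports)
    backend.configure_health_check(
        'my-lb', timeout=5, healthy_threshold=2, unhealthy_threshold=3,
        interval=30, target='HTTP:8080/health')
    return backend.describe_load_balancers(['my-lb'])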
class DynamoDBBackend(BaseBackend):
    def __init__(self):
        self.tables = OrderedDict()

    def create_table(self, name, **params):
        if name in self.tables:
            return None
        table = Table(name, **params)
        self.tables[name] = table
        return table

    def delete_table(self, name):
        return self.tables.pop(name, None)

    def update_table_throughput(self, name, throughput):
        table = self.tables[name]
        table.throughput = throughput
        return table

    def update_table_global_indexes(self, name, global_index_updates):
        table = self.tables[name]
        gsis_by_name = dict((i['IndexName'], i) for i in table.global_indexes)
        for gsi_update in global_index_updates:
            gsi_to_create = gsi_update.get('Create')
            gsi_to_update = gsi_update.get('Update')
            gsi_to_delete = gsi_update.get('Delete')
            if gsi_to_delete:
                index_name = gsi_to_delete['IndexName']
                if index_name not in gsis_by_name:
                    raise ValueError(
                        'Global Secondary Index does not exist, but tried to '
                        'delete: %s' % gsi_to_delete['IndexName'])
                del gsis_by_name[index_name]
            if gsi_to_update:
                index_name = gsi_to_update['IndexName']
                if index_name not in gsis_by_name:
                    raise ValueError(
                        'Global Secondary Index does not exist, but tried to '
                        'update: %s' % gsi_to_update['IndexName'])
                gsis_by_name[index_name].update(gsi_to_update)
            if gsi_to_create:
                if gsi_to_create['IndexName'] in gsis_by_name:
                    raise ValueError(
                        'Global Secondary Index already exists: %s'
                        % gsi_to_create['IndexName'])
                gsis_by_name[gsi_to_create['IndexName']] = gsi_to_create
        # materialize as a list: dict.values() is only a view on python 3
        table.global_indexes = list(gsis_by_name.values())
        return table

    def put_item(self, table_name, item_attrs, expected=None, overwrite=False):
        table = self.tables.get(table_name)
        if not table:
            return None
        return table.put_item(item_attrs, expected, overwrite)

    def get_table_keys_name(self, table_name, keys):
        """Given a set of keys, extracts the key and range key"""
        table = self.tables.get(table_name)
        if not table:
            return None, None
        if len(keys) == 1:
            for key in keys:
                if key in table.hash_key_names:
                    return key, None
        potential_hash, potential_range = None, None
        for key in set(keys):
            if key in table.hash_key_names:
                potential_hash = key
            elif key in table.range_key_names:
                potential_range = key
        return potential_hash, potential_range

    def get_keys_value(self, table, keys):
        if table.hash_key_attr not in keys or (
                table.has_range_key and table.range_key_attr not in keys):
            raise ValueError(
                "Table has a range key, but no range key was passed into get_item")
        hash_key = DynamoType(keys[table.hash_key_attr])
        range_key = DynamoType(
            keys[table.range_key_attr]) if table.has_range_key else None
        return hash_key, range_key

    def get_table(self, table_name):
        return self.tables.get(table_name)

    def get_item(self, table_name, keys):
        table = self.get_table(table_name)
        if not table:
            raise ValueError("No table found")
        hash_key, range_key = self.get_keys_value(table, keys)
        return table.get_item(hash_key, range_key)

    def query(self, table_name, hash_key_dict, range_comparison,
              range_value_dicts, limit, exclusive_start_key, scan_index_forward,
              index_name=None, **filter_kwargs):
        table = self.tables.get(table_name)
        if not table:
            return None, None
        hash_key = DynamoType(hash_key_dict)
        range_values = [DynamoType(range_value)
                        for range_value in range_value_dicts]
        return table.query(hash_key, range_comparison, range_values, limit,
                           exclusive_start_key, scan_index_forward, index_name,
                           **filter_kwargs)

    def scan(self, table_name, filters, limit, exclusive_start_key):
        table = self.tables.get(table_name)
        if not table:
            return None, None, None
        scan_filters = {}
        for key, (comparison_operator, comparison_values) in filters.items():
            dynamo_types = [DynamoType(value) for value in comparison_values]
            scan_filters[key] = (comparison_operator, dynamo_types)
        return table.scan(scan_filters, limit, exclusive_start_key)

    def update_item(self, table_name, key, update_expression, attribute_updates,
                    expression_attribute_names, expression_attribute_values):
        table = self.get_table(table_name)
        if all([table.hash_key_attr in key, table.range_key_attr in key]):
            # Covers cases where table has hash and range keys; ``key`` param
            # will be a dict
            hash_value = DynamoType(key[table.hash_key_attr])
            range_value = DynamoType(key[table.range_key_attr])
        elif table.hash_key_attr in key:
            # Covers tables that have a range key where ``key`` param is a dict
            hash_value = DynamoType(key[table.hash_key_attr])
            range_value = None
        else:
            # Covers other cases
            hash_value = DynamoType(key)
            range_value = None
        item = table.get_item(hash_value, range_value)
        # Update does not fail on new items, so create one
        if item is None:
            data = {table.hash_key_attr: {hash_value.type: hash_value.value}}
            if range_value:
                data.update({
                    table.range_key_attr: {range_value.type: range_value.value}})
            table.put_item(data)
            item = table.get_item(hash_value, range_value)
        if update_expression:
            item.update(update_expression, expression_attribute_names,
                        expression_attribute_values)
        else:
            item.update_with_attribute_updates(attribute_updates)
        return item

    def delete_item(self, table_name, keys):
        table = self.tables.get(table_name)
        if not table:
            return None
        hash_key, range_key = self.get_keys_value(table, keys)
        return table.delete_item(hash_key, range_key)
class DynamoDBBackend(BaseBackend):
    def __init__(self, region_name=None):
        self.region_name = region_name
        self.tables = OrderedDict()

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def create_table(self, name, **params):
        if name in self.tables:
            return None
        table = Table(name, **params)
        self.tables[name] = table
        return table

    def delete_table(self, name):
        return self.tables.pop(name, None)

    def tag_resource(self, table_arn, tags):
        for table in self.tables:
            if self.tables[table].table_arn == table_arn:
                self.tables[table].tags.extend(tags)

    def untag_resource(self, table_arn, tag_keys):
        for table in self.tables:
            if self.tables[table].table_arn == table_arn:
                self.tables[table].tags = [
                    tag for tag in self.tables[table].tags
                    if tag['Key'] not in tag_keys]

    def list_tags_of_resource(self, table_arn):
        required_table = None
        for table in self.tables:
            if self.tables[table].table_arn == table_arn:
                required_table = self.tables[table]
        return required_table.tags

    def update_table_throughput(self, name, throughput):
        table = self.tables[name]
        table.throughput = throughput
        return table

    def update_table_global_indexes(self, name, global_index_updates):
        table = self.tables[name]
        gsis_by_name = dict((i['IndexName'], i) for i in table.global_indexes)
        for gsi_update in global_index_updates:
            gsi_to_create = gsi_update.get('Create')
            gsi_to_update = gsi_update.get('Update')
            gsi_to_delete = gsi_update.get('Delete')
            if gsi_to_delete:
                index_name = gsi_to_delete['IndexName']
                if index_name not in gsis_by_name:
                    raise ValueError(
                        'Global Secondary Index does not exist, but tried to '
                        'delete: %s' % gsi_to_delete['IndexName'])
                del gsis_by_name[index_name]
            if gsi_to_update:
                index_name = gsi_to_update['IndexName']
                if index_name not in gsis_by_name:
                    raise ValueError(
                        'Global Secondary Index does not exist, but tried to '
                        'update: %s' % gsi_to_update['IndexName'])
                gsis_by_name[index_name].update(gsi_to_update)
            if gsi_to_create:
                if gsi_to_create['IndexName'] in gsis_by_name:
                    raise ValueError(
                        'Global Secondary Index already exists: %s'
                        % gsi_to_create['IndexName'])
                gsis_by_name[gsi_to_create['IndexName']] = gsi_to_create
        # materialize as a list: dict.values() is only a view on python 3
        table.global_indexes = list(gsis_by_name.values())
        return table

    def put_item(self, table_name, item_attrs, expected=None, overwrite=False):
        table = self.tables.get(table_name)
        if not table:
            return None
        return table.put_item(item_attrs, expected, overwrite)

    def get_table_keys_name(self, table_name, keys):
        """Given a set of keys, extracts the key and range key"""
        table = self.tables.get(table_name)
        if not table:
            return None, None
        if len(keys) == 1:
            for key in keys:
                if key in table.hash_key_names:
                    return key, None
        potential_hash, potential_range = None, None
        for key in set(keys):
            if key in table.hash_key_names:
                potential_hash = key
            elif key in table.range_key_names:
                potential_range = key
        return potential_hash, potential_range

    def get_keys_value(self, table, keys):
        if table.hash_key_attr not in keys or (
                table.has_range_key and table.range_key_attr not in keys):
            raise ValueError(
                "Table has a range key, but no range key was passed into get_item")
        hash_key = DynamoType(keys[table.hash_key_attr])
        range_key = DynamoType(
            keys[table.range_key_attr]) if table.has_range_key else None
        return hash_key, range_key

    def get_table(self, table_name):
        return self.tables.get(table_name)

    def get_item(self, table_name, keys):
        table = self.get_table(table_name)
        if not table:
            raise ValueError("No table found")
        hash_key, range_key = self.get_keys_value(table, keys)
        return table.get_item(hash_key, range_key)

    def query(self, table_name, hash_key_dict, range_comparison,
              range_value_dicts, limit, exclusive_start_key, scan_index_forward,
              projection_expression, index_name=None, expr_names=None,
              expr_values=None, filter_expression=None, **filter_kwargs):
        table = self.tables.get(table_name)
        if not table:
            return None, None
        hash_key = DynamoType(hash_key_dict)
        range_values = [DynamoType(range_value)
                        for range_value in range_value_dicts]
        if filter_expression is not None:
            filter_expression = get_filter_expression(
                filter_expression, expr_names, expr_values)
        else:
            filter_expression = Op(None, None)  # Will always eval to true
        return table.query(hash_key, range_comparison, range_values, limit,
                           exclusive_start_key, scan_index_forward,
                           projection_expression, index_name, filter_expression,
                           **filter_kwargs)

    def scan(self, table_name, filters, limit, exclusive_start_key,
             filter_expression, expr_names, expr_values):
        table = self.tables.get(table_name)
        if not table:
            return None, None, None
        scan_filters = {}
        for key, (comparison_operator, comparison_values) in filters.items():
            dynamo_types = [DynamoType(value) for value in comparison_values]
            scan_filters[key] = (comparison_operator, dynamo_types)
        if filter_expression is not None:
            filter_expression = get_filter_expression(
                filter_expression, expr_names, expr_values)
        else:
            filter_expression = Op(None, None)  # Will always eval to true
        return table.scan(scan_filters, limit, exclusive_start_key,
                          filter_expression)

    def update_item(self, table_name, key, update_expression, attribute_updates,
                    expression_attribute_names, expression_attribute_values,
                    expected=None):
        table = self.get_table(table_name)
        if all([table.hash_key_attr in key, table.range_key_attr in key]):
            # Covers cases where table has hash and range keys; ``key`` param
            # will be a dict
            hash_value = DynamoType(key[table.hash_key_attr])
            range_value = DynamoType(key[table.range_key_attr])
        elif table.hash_key_attr in key:
            # Covers tables that have a range key where ``key`` param is a dict
            hash_value = DynamoType(key[table.hash_key_attr])
            range_value = None
        else:
            # Covers other cases
            hash_value = DynamoType(key)
            range_value = None
        item = table.get_item(hash_value, range_value)
        if item is None:
            item_attr = {}
        elif hasattr(item, 'attrs'):
            item_attr = item.attrs
        else:
            item_attr = item
        if not expected:
            expected = {}
        for key, val in expected.items():
            if 'Exists' in val and val['Exists'] is False:
                if key in item_attr:
                    raise ValueError("The conditional request failed")
            elif key not in item_attr:
                raise ValueError("The conditional request failed")
            elif ('Value' in val and
                    DynamoType(val['Value']).value != item_attr[key].value):
                raise ValueError("The conditional request failed")
            elif 'ComparisonOperator' in val:
                comparison_func = get_comparison_func(val['ComparisonOperator'])
                dynamo_types = [DynamoType(ele)
                                for ele in val["AttributeValueList"]]
                for t in dynamo_types:
                    if not comparison_func(item_attr[key].value, t.value):
                        raise ValueError('The conditional request failed')
        # Update does not fail on new items, so create one
        if item is None:
            data = {table.hash_key_attr: {hash_value.type: hash_value.value}}
            if range_value:
                data.update({
                    table.range_key_attr: {range_value.type: range_value.value}})
            table.put_item(data)
            item = table.get_item(hash_value, range_value)
        if update_expression:
            item.update(update_expression, expression_attribute_names,
                        expression_attribute_values)
        else:
            item.update_with_attribute_updates(attribute_updates)
        return item

    def delete_item(self, table_name, keys):
        table = self.get_table(table_name)
        if not table:
            return None
        hash_key, range_key = self.get_keys_value(table, keys)
        return table.delete_item(hash_key, range_key)

    def update_ttl(self, table_name, ttl_spec):
        table = self.tables.get(table_name)
        if table is None:
            raise JsonRESTError('ResourceNotFound', 'Table not found')
        if 'Enabled' not in ttl_spec or 'AttributeName' not in ttl_spec:
            raise JsonRESTError(
                'InvalidParameterValue',
                'TimeToLiveSpecification does not contain Enabled and AttributeName')
        if ttl_spec['Enabled']:
            table.ttl['TimeToLiveStatus'] = 'ENABLED'
        else:
            table.ttl['TimeToLiveStatus'] = 'DISABLED'
        table.ttl['AttributeName'] = ttl_spec['AttributeName']

    def describe_ttl(self, table_name):
        table = self.tables.get(table_name)
        if table is None:
            raise JsonRESTError('ResourceNotFound', 'Table not found')
        return table.ttl
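# Hedged sketch of the TTL round trip implemented above (table name
# illustrative; the table must already exist or JsonRESTError is raised):
def _ttl_demo(backend):
    backend.update_ttl('my-table', {'Enabled': True, 'AttributeName': 'expires_at'})
    assert backend.describe_ttl('my-table')['TimeToLiveStatus'] == 'ENABLED'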
class ELBv2Backend(BaseBackend):
    def __init__(self, region_name=None):
        self.region_name = region_name
        self.target_groups = OrderedDict()
        self.load_balancers = OrderedDict()

    def reset(self):
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def create_load_balancer(self, name, security_groups, subnet_ids,
                             scheme='internet-facing'):
        vpc_id = None
        ec2_backend = ec2_backends[self.region_name]
        subnets = []
        if not subnet_ids:
            raise SubnetNotFoundError()
        for subnet_id in subnet_ids:
            subnet = ec2_backend.get_subnet(subnet_id)
            if subnet is None:
                raise SubnetNotFoundError()
            subnets.append(subnet)
        vpc_id = subnets[0].vpc_id
        arn = "arn:aws:elasticloadbalancing:%s:1:loadbalancer/%s/50dc6c495c0c9188" % (
            self.region_name, name)
        dns_name = "%s-1.%s.elb.amazonaws.com" % (name, self.region_name)
        if arn in self.load_balancers:
            raise DuplicateLoadBalancerName()
        new_load_balancer = FakeLoadBalancer(
            name=name, security_groups=security_groups, arn=arn, scheme=scheme,
            subnets=subnets, vpc_id=vpc_id, dns_name=dns_name)
        self.load_balancers[arn] = new_load_balancer
        return new_load_balancer

    def create_rule(self, listener_arn, conditions, priority, actions):
        listeners = self.describe_listeners(None, [listener_arn])
        if not listeners:
            raise ListenerNotFoundError()
        listener = listeners[0]
        # validate conditions
        for condition in conditions:
            field = condition['field']
            if field not in ['path-pattern', 'host-header']:
                raise InvalidConditionFieldError(field)
            values = condition['values']
            if len(values) == 0:
                raise InvalidConditionValueError(
                    'A condition value must be specified')
            if len(values) > 1:
                raise InvalidConditionValueError(
                    "The '%s' field contains too many values; the limit is '1'"
                    % field)
            # TODO: check pattern of value for 'host-header'
            # TODO: check pattern of value for 'path-pattern'
        # validate Priority
        for rule in listener.rules:
            if rule.priority == priority:
                raise PriorityInUseError()
        # validate Actions
        target_group_arns = [
            target_group.arn for target_group in self.target_groups.values()]
        for i, action in enumerate(actions):
            index = i + 1
            action_type = action['type']
            if action_type not in ['forward']:
                raise InvalidActionTypeError(action_type, index)
            action_target_group_arn = action['target_group_arn']
            if action_target_group_arn not in target_group_arns:
                raise ActionTargetGroupNotFoundError(action_target_group_arn)
        # TODO: check for error 'TooManyRegistrationsForTargetId'
        # TODO: check for error 'TooManyRules'
        # create rule
        rule = FakeRule(listener.arn, conditions, priority, actions,
                        is_default=False)
        listener.register(rule)
        return [rule]

    def create_target_group(self, name, **kwargs):
        if len(name) > 32:
            raise InvalidTargetGroupNameError(
                "Target group name '%s' cannot be longer than '32' characters"
                % name)
        if not re.match(r'^[a-zA-Z0-9\-]+$', name):
            raise InvalidTargetGroupNameError(
                "Target group name '%s' can only contain characters that are "
                "alphanumeric characters or hyphens(-)" % name)
        # undocumented validation
        if not re.match(r'(?!.*--)(?!^-)(?!.*-$)^[A-Za-z0-9-]+$', name):
            raise InvalidTargetGroupNameError(
                "1 validation error detected: Value '%s' at "
                "'targetGroup.targetGroupArn.targetGroupName' failed to satisfy "
                "constraint: Member must satisfy regular expression pattern: "
                "(?!.*--)(?!^-)(?!.*-$)^[A-Za-z0-9-]+$" % name)
        if name.startswith('-') or name.endswith('-'):
            raise InvalidTargetGroupNameError(
                "Target group name '%s' cannot begin or end with '-'" % name)
        for target_group in self.target_groups.values():
            if target_group.name == name:
                raise DuplicateTargetGroupName()
        arn = "arn:aws:elasticloadbalancing:%s:1:targetgroup/%s/50dc6c495c0c9188" % (
            self.region_name, name)
        target_group = FakeTargetGroup(name, arn, **kwargs)
        self.target_groups[target_group.arn] = target_group
        return target_group

    def create_listener(self, load_balancer_arn, protocol, port, ssl_policy,
                        certificate, default_actions):
        balancer = self.load_balancers.get(load_balancer_arn)
        if balancer is None:
            raise LoadBalancerNotFoundError()
        if port in balancer.listeners:
            raise DuplicateListenerError()
        arn = load_balancer_arn.replace(
            ':loadbalancer/', ':listener/') + "/%s%s" % (port, id(self))
        listener = FakeListener(load_balancer_arn, arn, protocol, port,
                                ssl_policy, certificate, default_actions)
        balancer.listeners[listener.arn] = listener
        return listener

    def describe_load_balancers(self, arns, names):
        balancers = self.load_balancers.values()
        arns = arns or []
        names = names or []
        if not arns and not names:
            return balancers
        matched_balancers = []
        for arn in arns:
            # reset per lookup so a previous match cannot mask a miss
            matched_balancer = None
            for balancer in balancers:
                if balancer.arn == arn:
                    matched_balancer = balancer
            if matched_balancer is None:
                raise LoadBalancerNotFoundError()
            elif matched_balancer not in matched_balancers:
                matched_balancers.append(matched_balancer)
        for name in names:
            matched_balancer = None
            for balancer in balancers:
                if balancer.name == name:
                    matched_balancer = balancer
            if matched_balancer is None:
                raise LoadBalancerNotFoundError()
            elif matched_balancer not in matched_balancers:
                matched_balancers.append(matched_balancer)
        return matched_balancers

    def describe_rules(self, listener_arn, rule_arns):
        if listener_arn is None and not rule_arns:
            raise InvalidDescribeRulesRequest(
                "You must specify either listener rule ARNs or a listener ARN")
        if listener_arn is not None and rule_arns is not None:
            raise InvalidDescribeRulesRequest(
                'Listener rule ARNs and a listener ARN cannot be specified at '
                'the same time')
        if listener_arn:
            listener = self.describe_listeners(None, [listener_arn])[0]
            return listener.rules
        # search for rule arns
        matched_rules = []
        for load_balancer_arn in self.load_balancers:
            listeners = self.load_balancers.get(
                load_balancer_arn).listeners.values()
            for listener in listeners:
                for rule in listener.rules:
                    if rule.arn in rule_arns:
                        matched_rules.append(rule)
        return matched_rules

    def describe_target_groups(self, load_balancer_arn, target_group_arns, names):
        if load_balancer_arn:
            if load_balancer_arn not in self.load_balancers:
                raise LoadBalancerNotFoundError()
            return [tg for tg in self.target_groups.values()
                    if load_balancer_arn in tg.load_balancer_arns]
        if target_group_arns:
            try:
                return [self.target_groups[arn] for arn in target_group_arns]
            except KeyError:
                raise TargetGroupNotFoundError()
        if names:
            matched = []
            for name in names:
                found = None
                for target_group in self.target_groups.values():
                    if target_group.name == name:
                        found = target_group
                if not found:
                    raise TargetGroupNotFoundError()
                matched.append(found)
            return matched
        return self.target_groups.values()

    def describe_listeners(self, load_balancer_arn, listener_arns):
        if load_balancer_arn:
            if load_balancer_arn not in self.load_balancers:
                raise LoadBalancerNotFoundError()
            return self.load_balancers.get(load_balancer_arn).listeners.values()
        matched = []
        for load_balancer in self.load_balancers.values():
            for listener_arn in listener_arns:
                listener = load_balancer.listeners.get(listener_arn)
                if not listener:
                    raise ListenerNotFoundError()
                matched.append(listener)
        return matched

    def delete_load_balancer(self, arn):
        self.load_balancers.pop(arn, None)

    def delete_rule(self, arn):
        for load_balancer_arn in self.load_balancers:
            listeners = self.load_balancers.get(
                load_balancer_arn).listeners.values()
            for listener in listeners:
                for rule in listener.rules:
                    if rule.arn == arn:
                        listener.remove_rule(rule)
                        return
        # should raise RuleNotFound Error according to the AWS API doc
        # however, boto3 doesn't raise an error even if the rule is not found

    def delete_target_group(self, target_group_arn):
        if target_group_arn not in self.target_groups:
            raise TargetGroupNotFoundError()
        target_group = self.target_groups[target_group_arn]
        if target_group:
            if self._any_listener_using(target_group_arn):
                raise ResourceInUseError(
                    "The target group '{}' is currently in use by a listener "
                    "or a rule".format(target_group_arn))
            del self.target_groups[target_group_arn]
            return target_group

    def delete_listener(self, listener_arn):
        for load_balancer in self.load_balancers.values():
            listener = load_balancer.listeners.pop(listener_arn, None)
            if listener:
                return listener
        raise ListenerNotFoundError()

    def modify_rule(self, rule_arn, conditions, actions):
        # if conditions or actions is an empty list, do not update the attributes
        if not conditions and not actions:
            raise InvalidModifyRuleArgumentsError()
        rules = self.describe_rules(listener_arn=None, rule_arns=[rule_arn])
        if not rules:
            raise RuleNotFoundError()
        rule = rules[0]
        if conditions:
            for condition in conditions:
                field = condition['field']
                if field not in ['path-pattern', 'host-header']:
                    raise InvalidConditionFieldError(field)
                values = condition['values']
                if len(values) == 0:
                    raise InvalidConditionValueError(
                        'A condition value must be specified')
                if len(values) > 1:
                    raise InvalidConditionValueError(
                        "The '%s' field contains too many values; the limit "
                        "is '1'" % field)
                # TODO: check pattern of value for 'host-header'
                # TODO: check pattern of value for 'path-pattern'
        # validate Actions
        target_group_arns = [
            target_group.arn for target_group in self.target_groups.values()]
        if actions:
            for i, action in enumerate(actions):
                index = i + 1
                action_type = action['type']
                if action_type not in ['forward']:
                    raise InvalidActionTypeError(action_type, index)
                action_target_group_arn = action['target_group_arn']
                if action_target_group_arn not in target_group_arns:
                    raise ActionTargetGroupNotFoundError(action_target_group_arn)
        # TODO: check for error 'TooManyRegistrationsForTargetId'
        # TODO: check for error 'TooManyRules'
        # modify rule
        if conditions:
            rule.conditions = conditions
        if actions:
            rule.actions = actions
        return [rule]

    def register_targets(self, target_group_arn, instances):
        target_group = self.target_groups.get(target_group_arn)
        if target_group is None:
            raise TargetGroupNotFoundError()
        target_group.register(instances)

    def deregister_targets(self, target_group_arn, instances):
        target_group = self.target_groups.get(target_group_arn)
        if target_group is None:
            raise TargetGroupNotFoundError()
        target_group.deregister(instances)

    def describe_target_health(self, target_group_arn, targets):
        target_group = self.target_groups.get(target_group_arn)
        if target_group is None:
            raise TargetGroupNotFoundError()
        if not targets:
            targets = target_group.targets.values()
        return [target_group.health_for(target) for target in targets]

    def set_rule_priorities(self, rule_priorities):
        # validate: no duplicate priorities within the request itself
        priorities = [rule_priority['priority']
                      for rule_priority in rule_priorities]
        for priority in set(priorities):
            if priorities.count(priority) > 1:
                raise DuplicatePriorityError(priority)
        # validate: no clash with priorities already on each rule's listener
        for rule_priority in rule_priorities:
            given_rule_arn = rule_priority['rule_arn']
            priority = rule_priority['priority']
            _given_rules = self.describe_rules(
                listener_arn=None, rule_arns=[given_rule_arn])
            if not _given_rules:
                raise RuleNotFoundError()
            given_rule = _given_rules[0]
            listeners = self.describe_listeners(None, [given_rule.listener_arn])
            listener = listeners[0]
            for rule_in_listener in listener.rules:
                if rule_in_listener.priority == priority:
                    raise PriorityInUseError()
        # modify
        modified_rules = []
        for rule_priority in rule_priorities:
            given_rule_arn = rule_priority['rule_arn']
            priority = rule_priority['priority']
            _given_rules = self.describe_rules(
                listener_arn=None, rule_arns=[given_rule_arn])
            if not _given_rules:
                raise RuleNotFoundError()
            given_rule = _given_rules[0]
            given_rule.priority = priority
            modified_rules.append(given_rule)
        return modified_rules

    def _any_listener_using(self, target_group_arn):
        for load_balancer in self.load_balancers.values():
            for listener in load_balancer.listeners.values():
                for rule in listener.rules:
                    for action in rule.actions:
                        if action.get('target_group_arn') == target_group_arn:
                            return True
        return False
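# End-to-end sketch for the ALB backend above. The subnet ids are illustrative
# and must already exist in the region's ec2 backend; the target-group kwargs
# mirror the FakeTargetGroup constructor create_target_group forwards to.
def _elbv2_demo():
    backend = ELBv2Backend('us-east-1')
    lb = backend.create_load_balancer(
        'my-alb', security_groups=[], subnet_ids=['subnet-1', 'subnet-2'])
    tg = backend.create_target_group(
        'my-targets', vpc_id=lb.vpc_id, protocol='HTTP', port=80,
        healthcheck_protocol='HTTP', healthcheck_port=80, healthcheck_path='/',
        healthcheck_interval_seconds=30, healthcheck_timeout_seconds=5,
        healthy_threshold_count=5, unhealthy_threshold_count=2)
    listener = backend.create_listener(
        lb.arn, 'HTTP', 80, ssl_policy=None, certificate=None,
        default_actions=[{'type': 'forward', 'target_group_arn': tg.arn}])
    rules = backend.create_rule(
        listener.arn,
        conditions=[{'field': 'path-pattern', 'values': ['/api/*']}],
        priority=10,
        actions=[{'type': 'forward', 'target_group_arn': tg.arn}])
    return rules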
class DynamoDBBackend(BaseBackend):
    def __init__(self):
        self.tables = OrderedDict()

    def create_table(self, name, **params):
        table = Table(name, **params)
        self.tables[name] = table
        return table

    def delete_table(self, name):
        return self.tables.pop(name, None)

    def update_table_throughput(self, name, throughput):
        table = self.tables[name]
        table.throughput = throughput
        return table

    def put_item(self, table_name, item_attrs):
        table = self.tables.get(table_name)
        if not table:
            return None
        return table.put_item(item_attrs)

    def get_table_keys_name(self, table_name):
        table = self.tables.get(table_name)
        if not table:
            return None, None
        return table.hash_key_attr, table.range_key_attr

    def get_keys_value(self, table, keys):
        if table.hash_key_attr not in keys or (
                table.has_range_key and table.range_key_attr not in keys):
            raise ValueError(
                "Table has a range key, but no range key was passed into get_item")
        hash_key = DynamoType(keys[table.hash_key_attr])
        range_key = DynamoType(
            keys[table.range_key_attr]) if table.has_range_key else None
        return hash_key, range_key

    def get_table(self, table_name):
        return self.tables.get(table_name)

    def get_item(self, table_name, keys):
        table = self.get_table(table_name)
        if not table:
            return None
        hash_key, range_key = self.get_keys_value(table, keys)
        return table.get_item(hash_key, range_key)

    def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts):
        table = self.tables.get(table_name)
        if not table:
            return None, None
        hash_key = DynamoType(hash_key_dict)
        range_values = [DynamoType(range_value)
                        for range_value in range_value_dicts]
        return table.query(hash_key, range_comparison, range_values)

    def scan(self, table_name, filters):
        table = self.tables.get(table_name)
        if not table:
            return None, None, None
        scan_filters = {}
        for key, (comparison_operator, comparison_values) in filters.items():
            dynamo_types = [DynamoType(value) for value in comparison_values]
            scan_filters[key] = (comparison_operator, dynamo_types)
        return table.scan(scan_filters)

    def update_item(self, table_name, key, update_expression):
        table = self.get_table(table_name)
        hash_value = DynamoType(key)
        item = table.get_item(hash_value)
        item.update(update_expression)
        return item

    def delete_item(self, table_name, keys):
        table = self.tables.get(table_name)
        if not table:
            return None
        hash_key, range_key = self.get_keys_value(table, keys)
        return table.delete_item(hash_key, range_key)