def test_list_and_delete_apis(apigateway_client):
    api_name1 = short_uid()
    api_name2 = short_uid()

    response = apigateway_client.create_rest_api(name=api_name1, description="this is my api")
    api_id = response["id"]
    apigateway_client.create_rest_api(name=api_name2, description="this is my api2")

    response = apigateway_client.get_rest_apis()
    items = [item for item in response["items"] if item["name"] in [api_name1, api_name2]]
    assert len(items) == 2

    apigateway_client.delete_rest_api(restApiId=api_id)
    response = apigateway_client.get_rest_apis()
    items = [item for item in response["items"] if item["name"] in [api_name1, api_name2]]
    assert len(items) == 1
def associate_vpc_with_hosted_zone(
    self,
    context: RequestContext,
    hosted_zone_id: ResourceId,
    vpc: VPC,
    comment: AssociateVPCComment = None,
) -> AssociateVPCWithHostedZoneResponse:
    region_details = Route53Backend.get()
    # TODO: handle NoSuchHostedZone and ConflictingDomainExist
    zone_details = region_details.vpc_hosted_zone_associations.get(hosted_zone_id) or []
    hosted_zone_association = HostedZoneAssociation(
        hosted_zone_id=hosted_zone_id,
        id=short_uid(),
        vpc=vpc,
        status=ChangeStatus.INSYNC,
        submitted_at=datetime.now(),
    )
    zone_details.append(hosted_zone_association)
    vpc_id = vpc.get("VPCId")
    # update VPC info in hosted zone moto object - fixes required after https://github.com/spulec/moto/pull/4786
    hosted_zone = route53_backend.zones.get(hosted_zone_id)
    if not getattr(hosted_zone, "vpcid", None):
        hosted_zone.vpcid = vpc_id
    if not getattr(hosted_zone, "vpcregion", None):
        hosted_zone.vpcregion = aws_stack.get_region()
    region_details.vpc_hosted_zone_associations[hosted_zone_id] = zone_details
    return AssociateVPCWithHostedZoneResponse(
        ChangeInfo=ChangeInfo(
            Id=short_uid(), Status=ChangeStatus.INSYNC, SubmittedAt=datetime.now()
        )
    )
def test_object_created_and_object_removed(
    self,
    s3_client,
    sqs_client,
    s3_create_bucket,
    sqs_create_queue,
    s3_create_sqs_bucket_notification,
    snapshot,
):
    snapshot.add_transformer(snapshot.transform.sqs_api())
    snapshot.add_transformer(snapshot.transform.s3_api())
    snapshot.add_transformer(snapshot.transform.jsonpath("$..s3.object.key", "object-key"))

    # setup fixture
    bucket_name = s3_create_bucket()
    queue_url = sqs_create_queue()
    s3_create_sqs_bucket_notification(
        bucket_name, queue_url, ["s3:ObjectCreated:*", "s3:ObjectRemoved:*"]
    )

    src_key = "src-dest-%s" % short_uid()
    dest_key = "key-dest-%s" % short_uid()
    # event0 = PutObject
    s3_client.put_object(Bucket=bucket_name, Key=src_key, Body="something")
    # event1 = CopyObject
    s3_client.copy_object(
        Bucket=bucket_name,
        CopySource={"Bucket": bucket_name, "Key": src_key},
        Key=dest_key,
    )
    # event2 = DeleteObject
    s3_client.delete_object(Bucket=bucket_name, Key=src_key)

    # collect events
    events = sqs_collect_s3_events(sqs_client, queue_url, 3)
    assert len(events) == 3, f"unexpected number of events in {events}"

    # the order does not seem to be guaranteed - sort so we can rely on it below
    events.sort(key=lambda x: x["eventName"])

    snapshot.match("receive_messages", {"messages": events})
    assert events[1]["eventName"] == "ObjectCreated:Put"
    assert events[1]["s3"]["bucket"]["name"] == bucket_name
    assert events[1]["s3"]["object"]["key"] == src_key

    assert events[0]["eventName"] == "ObjectCreated:Copy"
    assert events[0]["s3"]["bucket"]["name"] == bucket_name
    assert events[0]["s3"]["object"]["key"] == dest_key

    assert events[2]["eventName"] == "ObjectRemoved:Delete"
    assert events[2]["s3"]["bucket"]["name"] == bucket_name
    assert events[2]["s3"]["object"]["key"] == src_key
def requests_error_response_xml_signature_calculation(
    message,
    string_to_sign=None,
    signature=None,
    expires=None,
    code=400,
    code_string="AccessDenied",
    aws_access_token="temp",
):
    response = RequestsResponse()
    response_template = """<?xml version="1.0" encoding="UTF-8"?>
    <Error>
        <Code>{code_string}</Code>
        <Message>{message}</Message>
        <RequestId>{req_id}</RequestId>
        <HostId>{host_id}</HostId>
    </Error>""".format(
        message=message,
        code_string=code_string,
        req_id=short_uid(),
        host_id=short_uid(),
    )

    parsed_response = xmltodict.parse(response_template)
    response.status_code = code

    # explicit parentheses to make the operator precedence obvious; note that a
    # "SignatureDoesNotMatch" response is expected to carry a signature as well
    if (signature and string_to_sign) or code_string == "SignatureDoesNotMatch":
        bytes_signature = binascii.hexlify(bytes(signature, encoding="utf-8"))
        parsed_response["Error"]["Code"] = code_string
        parsed_response["Error"]["AWSAccessKeyId"] = aws_access_token
        parsed_response["Error"]["StringToSign"] = string_to_sign
        parsed_response["Error"]["SignatureProvided"] = signature
        parsed_response["Error"]["StringToSignBytes"] = "{}".format(bytes_signature.decode("utf-8"))
        set_response_content(response, xmltodict.unparse(parsed_response))

    if expires and code_string == "AccessDenied":
        server_time = datetime.datetime.utcnow().isoformat()[:-4]
        expires_isoformat = datetime.datetime.fromtimestamp(int(expires)).isoformat()[:-4]
        parsed_response["Error"]["Code"] = code_string
        parsed_response["Error"]["Expires"] = "{}Z".format(expires_isoformat)
        parsed_response["Error"]["ServerTime"] = "{}Z".format(server_time)
        set_response_content(response, xmltodict.unparse(parsed_response))

    if not signature and not expires and code_string == "AccessDenied":
        set_response_content(response, xmltodict.unparse(parsed_response))

    if response._content:
        return response
def _generate_machine_id() -> str:
    if config.is_in_docker:
        return short_uid()

    # this can potentially be useful when generated on the host using the CLI and then mounted
    # into the container via machine.json
    try:
        if os.path.exists("/etc/machine-id"):
            with open("/etc/machine-id") as fd:
                return md5(str(fd.read()))[:8]
    except Exception:
        pass

    # always fall back to short_uid()
    return short_uid()
def __init__(
    self,
    method,
    path,
    data,
    headers,
    api_id=None,
    stage=None,
    context=None,
    auth_info=None,
):
    self.method = method
    self.path = path
    self.data = data
    self.headers = headers
    self.context = {"requestId": short_uid()} if context is None else context
    self.auth_info = {} if auth_info is None else auth_info
    self.apigw_version = None
    self.api_id = api_id
    self.stage = stage
    self.region_name = None
    self.integration = None
    self.resource = None
    self.resource_path = None
    self.path_with_query_string = None
    self.response_templates = {}
    self.stage_variables = {}
    self.path_params = {}
    self.ws_route = None
def __init__(self, metadata=None, template=None):
    if template is None:
        template = {}
    self.metadata = metadata or {}
    self.template = template or {}
    self._template_raw = clone_safe(self.template)
    self.template_original = clone_safe(self.template)
    # initialize resources
    for resource_id, resource in self.template_resources.items():
        resource["LogicalResourceId"] = self.template_original["Resources"][resource_id][
            "LogicalResourceId"
        ] = resource.get("LogicalResourceId") or resource_id
    # initialize stack template attributes
    stack_id = self.metadata.get("StackId") or aws_stack.cloudformation_stack_arn(
        self.stack_name, short_uid()
    )
    self.template["StackId"] = self.metadata["StackId"] = stack_id
    self.template["Parameters"] = self.template.get("Parameters") or {}
    self.template["Outputs"] = self.template.get("Outputs") or {}
    self.template["Conditions"] = self.template.get("Conditions") or {}
    # initialize metadata
    self.metadata["Parameters"] = self.metadata.get("Parameters") or []
    self.metadata["StackStatus"] = "CREATE_IN_PROGRESS"
    self.metadata["CreationTime"] = self.metadata.get("CreationTime") or timestamp_millis()
    # maps resource id to resource state
    self._resource_states = {}
    # list of stack events
    self.events = []
    # list of stack change sets
    self.change_sets = []
def get_parameters_for_import(
    self,
    context: RequestContext,
    key_id: KeyIdType,
    wrapping_algorithm: AlgorithmSpec,
    wrapping_key_spec: WrappingKeySpec,
) -> GetParametersForImportResponse:
    key = _generate_data_key_pair(
        {"KeySpec": wrapping_key_spec}, create_cipher=False, add_to_keys=False
    )
    import_token = short_uid()
    import_state = KeyImportState(
        key_id=key_id,
        import_token=import_token,
        private_key=key["PrivateKeyPlaintext"],
        public_key=key["PublicKey"],
        wrapping_algo=wrapping_algorithm,
        key_obj=key["_key_"],
    )
    KMSBackend.get().imports[import_token] = import_state
    expiry_date = datetime.datetime.now() + datetime.timedelta(days=100)
    return GetParametersForImportResponse(
        KeyId=key_id,
        ImportToken=to_bytes(import_state.import_token),
        PublicKey=import_state.public_key,
        ParametersValidTo=expiry_date,
    )
def create_stack_instances(
    self,
    context: RequestContext,
    request: CreateStackInstancesInput,
) -> CreateStackInstancesOutput:
    state = CloudFormationRegion.get()

    set_name = request.get("StackSetName")
    stack_set = [sset for sset in state.stack_sets.values() if sset.stack_set_name == set_name]

    if not stack_set:
        return not_found_error(f'Stack set named "{set_name}" does not exist')

    stack_set = stack_set[0]
    op_id = request.get("OperationId") or short_uid()
    sset_meta = stack_set.metadata
    accounts = request["Accounts"]
    regions = request["Regions"]

    stacks_to_await = []
    for account in accounts:
        for region in regions:
            # deploy new stack
            LOG.debug('Deploying instance for stack set "%s" in region "%s"', set_name, region)
            cf_client = aws_stack.connect_to_service("cloudformation", region_name=region)
            kwargs = select_attributes(sset_meta, ["TemplateBody"]) or select_attributes(
                sset_meta, ["TemplateURL"]
            )
            stack_name = f"sset-{set_name}-{account}"
            result = cf_client.create_stack(StackName=stack_name, **kwargs)
            stacks_to_await.append((stack_name, region))
            # store stack instance
            instance = {
                "StackSetId": sset_meta["StackSetId"],
                "OperationId": op_id,
                "Account": account,
                "Region": region,
                "StackId": result["StackId"],
                "Status": "CURRENT",
                "StackInstanceStatus": {"DetailedStatus": "SUCCEEDED"},
            }
            instance = StackInstance(instance)
            stack_set.stack_instances.append(instance)

    # wait for completion of stacks
    for stack in stacks_to_await:
        aws_stack.await_stack_completion(stack[0], region_name=stack[1])

    # record operation
    operation = {
        "OperationId": op_id,
        "StackSetId": stack_set.metadata["StackSetId"],
        "Action": "CREATE",
        "Status": "SUCCEEDED",
    }
    stack_set.operations[op_id] = operation

    return CreateStackInstancesOutput(OperationId=op_id)
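# Usage sketch (not part of the provider above): the stack-set flow can be
# exercised end-to-end with boto3 against a running LocalStack instance. The
# endpoint URL, credentials, and template below are illustrative assumptions.
import boto3

cf = boto3.client(
    "cloudformation",
    endpoint_url="http://localhost:4566",  # assumed LocalStack edge endpoint
    region_name="us-east-1",
    aws_access_key_id="test",
    aws_secret_access_key="test",
)
template_body = "Resources:\n  Topic:\n    Type: AWS::SNS::Topic"
cf.create_stack_set(StackSetName="demo-set", TemplateBody=template_body)
# create_stack_instances deploys one stack per (account, region) pair and
# records a CREATE operation, as implemented above
response = cf.create_stack_instances(
    StackSetName="demo-set", Accounts=["000000000000"], Regions=["us-east-1"]
)
print(response["OperationId"])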
def test_object_created_copy(
    self,
    s3_client,
    sqs_client,
    s3_create_bucket,
    sqs_create_queue,
    s3_create_sqs_bucket_notification,
    snapshot,
):
    snapshot.add_transformer(snapshot.transform.sqs_api())
    snapshot.add_transformer(snapshot.transform.s3_api())
    snapshot.add_transformer(snapshot.transform.jsonpath("$..s3.object.key", "object-key"))

    # setup fixture
    bucket_name = s3_create_bucket()
    queue_url = sqs_create_queue()
    s3_create_sqs_bucket_notification(bucket_name, queue_url, ["s3:ObjectCreated:Copy"])

    src_key = "src-dest-%s" % short_uid()
    dest_key = "key-dest-%s" % short_uid()

    s3_client.put_object(Bucket=bucket_name, Key=src_key, Body="something")
    assert not sqs_collect_s3_events(
        sqs_client, queue_url, 0, timeout=1
    ), "unexpected event triggered for put_object"

    s3_client.copy_object(
        Bucket=bucket_name,
        CopySource={"Bucket": bucket_name, "Key": src_key},
        Key=dest_key,
    )

    events = sqs_collect_s3_events(sqs_client, queue_url, 1)
    assert len(events) == 1, f"unexpected number of events in {events}"

    snapshot.match("receive_messages", {"messages": events})
    assert events[0]["eventSource"] == "aws:s3"
    assert events[0]["eventName"] == "ObjectCreated:Copy"
    assert events[0]["s3"]["bucket"]["name"] == bucket_name
    assert events[0]["s3"]["object"]["key"] == dest_key
def test_object_tagging_delete_event(
    self,
    s3_client,
    sqs_client,
    s3_create_bucket,
    sqs_create_queue,
    s3_create_sqs_bucket_notification,
    snapshot,
):
    snapshot.add_transformer(snapshot.transform.sqs_api())
    snapshot.add_transformer(snapshot.transform.s3_api())
    snapshot.add_transformer(snapshot.transform.jsonpath("$..s3.object.key", "object-key"))

    # setup fixture
    bucket_name = s3_create_bucket()
    queue_url = sqs_create_queue()
    s3_create_sqs_bucket_notification(bucket_name, queue_url, ["s3:ObjectTagging:Delete"])

    dest_key = "key-dest-%s" % short_uid()

    s3_client.put_object(Bucket=bucket_name, Key=dest_key, Body="FooBarBlitz")
    assert not sqs_collect_s3_events(
        sqs_client, queue_url, 0, timeout=1
    ), "unexpected event triggered for put_object"

    s3_client.put_object_tagging(
        Bucket=bucket_name,
        Key=dest_key,
        Tagging={
            "TagSet": [
                {"Key": "swallow_type", "Value": "african"},
            ]
        },
    )

    s3_client.delete_object_tagging(Bucket=bucket_name, Key=dest_key)

    events = sqs_collect_s3_events(sqs_client, queue_url, 1)
    assert len(events) == 1, f"unexpected number of events in {events}"

    snapshot.match("receive_messages", {"messages": events})
    assert events[0]["eventSource"] == "aws:s3"
    assert events[0]["eventName"] == "ObjectTagging:Delete"
    assert events[0]["s3"]["bucket"]["name"] == bucket_name
    assert events[0]["s3"]["object"]["key"] == dest_key
def create_stack_set(
    self, context: RequestContext, request: CreateStackSetInput
) -> CreateStackSetOutput:
    state = CloudFormationRegion.get()
    stack_set = StackSet(request)
    stack_set_id = short_uid()
    stack_set.metadata["StackSetId"] = stack_set_id
    state.stack_sets[stack_set_id] = stack_set
    return CreateStackSetOutput(StackSetId=stack_set_id)
def create_vpc_link(
    self,
    context: RequestContext,
    name: String,
    target_arns: ListOfString,
    description: String = None,
    tags: MapOfStringToString = None,
) -> VpcLink:
    region_details = APIGatewayRegion.get()
    link_id = short_uid()
    entry = {"id": link_id, "status": "AVAILABLE"}
    region_details.vpc_links[link_id] = entry
    result = to_vpc_link_response_json(entry)
    return VpcLink(**result)
def events_handler_put_events(self):
    entries = self._get_param("Entries")

    # keep track of events for local integration testing
    if config.is_local_test_mode():
        TEST_EVENTS_CACHE.extend(entries)

    events = list(map(lambda event: {"event": event, "uuid": str(long_uid())}, entries))

    _dump_events_to_files(events)
    event_rules = self.events_backend.rules

    for event_envelope in events:
        event = event_envelope["event"]
        event_bus = event.get("EventBusName") or DEFAULT_EVENT_BUS_NAME
        matching_rules = [r for r in event_rules.values() if r.event_bus_name == event_bus]
        if not matching_rules:
            continue

        formatted_event = {
            "version": "0",
            "id": event_envelope["uuid"],
            "detail-type": event.get("DetailType"),
            "source": event.get("Source"),
            "account": TEST_AWS_ACCOUNT_ID,
            "time": datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ"),
            "region": self.region,
            "resources": event.get("Resources", []),
            "detail": json.loads(event.get("Detail", "{}")),
        }

        targets = []
        for rule in matching_rules:
            if filter_event_based_on_event_format(self, rule.name, formatted_event):
                targets.extend(self.events_backend.list_targets_by_rule(rule.name)["Targets"])

        # process event
        process_events(formatted_event, targets)

    content = {
        "FailedEntryCount": 0,  # TODO: dynamically set proper value when refactoring
        "Entries": list(map(lambda event: {"EventId": event["uuid"]}, events)),
    }

    self.response_headers.update(
        {"Content-Type": APPLICATION_AMZ_JSON_1_1, "x-amzn-RequestId": short_uid()}
    )

    return json.dumps(content), self.response_headers
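# Client-side sketch of the request served by events_handler_put_events above
# (illustrative only; endpoint and credentials are assumptions). Each entry is
# matched against the rules of its event bus and forwarded to their targets;
# the response carries one EventId per entry.
import json
import boto3

events_client = boto3.client(
    "events",
    endpoint_url="http://localhost:4566",  # assumed LocalStack edge endpoint
    region_name="us-east-1",
    aws_access_key_id="test",
    aws_secret_access_key="test",
)
response = events_client.put_events(
    Entries=[
        {
            "Source": "my.app",
            "DetailType": "OrderCreated",
            "Detail": json.dumps({"orderId": "123"}),
        }
    ]
)
assert response["FailedEntryCount"] == 0
print(response["Entries"][0]["EventId"])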
def create_authorizer(
    self, context: RequestContext, request: CreateAuthorizerRequest
) -> Authorizer:
    region_details = APIGatewayRegion.get()

    api_id = request["restApiId"]
    authorizer_id = short_uid()[:6]  # length 6 to make TF tests pass

    result = deepcopy(request)

    result["id"] = authorizer_id
    result = normalize_authorizer(result)
    region_details.authorizers.setdefault(api_id, []).append(result)

    result = to_authorizer_response_json(api_id, result)
    return Authorizer(**result)
def __init__(self, params=None, template=None):
    if template is None:
        template = {}
    if params is None:
        params = {}
    super(StackChangeSet, self).__init__(params, template)

    name = self.metadata["ChangeSetName"]
    if not self.metadata.get("ChangeSetId"):
        self.metadata["ChangeSetId"] = aws_stack.cf_change_set_arn(
            name, change_set_id=short_uid()
        )

    stack = self.stack = find_stack(self.metadata["StackName"])
    self.metadata["StackId"] = stack.stack_id
    self.metadata["Status"] = "CREATE_PENDING"
def generate_client_certificate(
    self,
    context: RequestContext,
    description: String = None,
    tags: MapOfStringToString = None,
) -> ClientCertificate:
    region_details = APIGatewayRegion.get()
    cert_id = short_uid()
    creation_time = now_utc()
    entry = {
        "description": description,
        "tags": tags,
        "clientCertificateId": cert_id,
        "createdDate": creation_time,
        "expirationDate": creation_time + 60 * 60 * 24 * 30,  # assume 30 days validity
        "pemEncodedCertificate": "testcert-123",  # TODO return proper certificate!
    }
    region_details.client_certificates[cert_id] = entry
    result = to_client_cert_response_json(entry)
    return ClientCertificate(**result)
def get_stream_info(
    stream_name,
    log_file=None,
    shards=None,
    env=None,
    endpoint_url=None,
    ddb_lease_table_suffix=None,
    env_vars=None,
):
    if env_vars is None:
        env_vars = {}
    if not ddb_lease_table_suffix:
        ddb_lease_table_suffix = DEFAULT_DDB_LEASE_TABLE_SUFFIX
    # construct stream info
    env = aws_stack.get_environment(env)
    props_file = os.path.join(tempfile.gettempdir(), "kclipy.%s.properties" % short_uid())
    # make sure to convert stream ARN to stream name
    stream_name = aws_stack.kinesis_stream_name(stream_name)
    app_name = "%s%s" % (stream_name, ddb_lease_table_suffix)
    stream_info = {
        "name": stream_name,
        "region": aws_stack.get_region(),
        "shards": shards,
        "properties_file": props_file,
        "log_file": log_file,
        "app_name": app_name,
        "env_vars": env_vars,
    }
    # set local connection
    if aws_stack.is_local_env(env):
        stream_info["conn_kwargs"] = {
            "host": LOCALHOST,
            "port": config.service_port("kinesis"),
            "is_secure": bool(config.USE_SSL),
        }
    if endpoint_url:
        if "conn_kwargs" not in stream_info:
            stream_info["conn_kwargs"] = {}
        url = urlparse(endpoint_url)
        stream_info["conn_kwargs"]["host"] = url.hostname
        stream_info["conn_kwargs"]["port"] = url.port
        stream_info["conn_kwargs"]["is_secure"] = url.scheme == "https"
    return stream_info
def update_stack_set(
    self, context: RequestContext, request: UpdateStackSetInput
) -> UpdateStackSetOutput:
    state = CloudFormationRegion.get()
    set_name = request.get("StackSetName")
    stack_set = [sset for sset in state.stack_sets.values() if sset.stack_set_name == set_name]
    if not stack_set:
        return not_found_error(f'Stack set named "{set_name}" does not exist')
    stack_set = stack_set[0]
    stack_set.metadata.update(request)
    op_id = request.get("OperationId") or short_uid()
    operation = {
        "OperationId": op_id,
        "StackSetId": stack_set.metadata["StackSetId"],
        "Action": "UPDATE",
        "Status": "SUCCEEDED",
    }
    stack_set.operations[op_id] = operation
    return UpdateStackSetOutput(OperationId=op_id)
def requests_error_response_xml(
    message: str,
    code: Optional[int] = 400,
    code_string: Optional[str] = "InvalidParameter",
    service: Optional[str] = None,
    xmlns: Optional[str] = None,
):
    response = RequestsResponse()
    xmlns = xmlns or "http://%s.amazonaws.com/doc/2010-03-31/" % service
    response._content = """<ErrorResponse xmlns="{xmlns}"><Error>
        <Type>Sender</Type>
        <Code>{code_string}</Code>
        <Message>{message}</Message>
        </Error><RequestId>{req_id}</RequestId>
        </ErrorResponse>""".format(
        xmlns=xmlns, message=message, code_string=code_string, req_id=short_uid()
    )
    response.status_code = code
    return response
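# Example: rendering a "NotFound" error for the SNS service with the helper
# above - the service name is interpolated into the default xmlns, and a
# fresh request ID is embedded in the payload.
error = requests_error_response_xml(
    "Topic does not exist", code=404, code_string="NotFound", service="sns"
)
assert error.status_code == 404
assert "<Code>NotFound</Code>" in error._content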
def create_documentation_part(
    self,
    context: RequestContext,
    rest_api_id: String,
    location: DocumentationPartLocation,
    properties: String,
) -> DocumentationPart:
    region_details = APIGatewayRegion.get()

    entity_id = short_uid()[:6]  # length 6 for AWS parity / Terraform compatibility
    entry = {
        "id": entity_id,
        "restApiId": rest_api_id,
        "location": location,
        "properties": properties,
    }

    region_details.documentation_parts.setdefault(rest_api_id, []).append(entry)

    result = to_documentation_part_response_json(rest_api_id, entry)
    return DocumentationPart(**result)
def create_request_validator(
    self,
    context: RequestContext,
    rest_api_id: String,
    name: String = None,
    validate_request_body: Boolean = None,
    validate_request_parameters: Boolean = None,
) -> RequestValidator:
    region_details = APIGatewayRegion.get()

    # length 6 for AWS parity and TF compatibility
    validator_id = short_uid()[:6]

    entry = {
        "id": validator_id,
        "name": name,
        "restApiId": rest_api_id,
        "validateRequestBody": validate_request_body,
        "validateRequestParameters": validate_request_parameters,
    }
    region_details.validators.setdefault(rest_api_id, []).append(entry)

    return RequestValidator(**entry)
def render_velocity_template(template, context, variables=None, as_json=False):
    if variables is None:
        variables = {}

    if not template:
        return template

    # fix "#set" commands (note: `flags=` is required here - passed positionally,
    # re.MULTILINE would be interpreted as the `count` argument of re.sub)
    template = re.sub(r"(^|\n)#\s+set(.*)", r"\1#set\2", template, flags=re.MULTILINE)

    # enable syntax like "test#${foo.bar}"
    empty_placeholder = " __pLaCe-HoLdEr__ "
    template = re.sub(
        r"([^\s]+)#\$({)?(.*)",
        r"\1#%s$\2\3" % empty_placeholder,
        template,
        flags=re.MULTILINE,
    )

    # add extensions for common string functions below

    class ExtendedString(str):
        def trim(self, *args, **kwargs):
            return ExtendedString(self.strip(*args, **kwargs))

        def toLowerCase(self, *args, **kwargs):
            return ExtendedString(self.lower(*args, **kwargs))

        def toUpperCase(self, *args, **kwargs):
            return ExtendedString(self.upper(*args, **kwargs))

    def apply(obj, **kwargs):
        if isinstance(obj, dict):
            for k, v in obj.items():
                if isinstance(v, str):
                    obj[k] = ExtendedString(v)
        return obj

    # loop through the variables and enable certain additional util functions (e.g., string utils)
    variables = variables or {}
    recurse_object(variables, apply)

    # prepare and render template
    context_var = variables.get("context") or {}
    context_var.setdefault("requestId", short_uid())
    t = airspeed.Template(template)
    var_map = {
        "input": VelocityInput(context),
        "util": VelocityUtil(),
        "context": context_var,
    }
    var_map.update(variables or {})
    replaced = t.merge(var_map)

    # revert temporary changes from the fixes above
    replaced = replaced.replace(empty_placeholder, "")

    if as_json:
        replaced = json.loads(replaced)
    return replaced
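# Usage sketch for the renderer above (assumes airspeed is installed and that
# VelocityInput/VelocityUtil behave like their API Gateway counterparts).
# Custom variables are merged into the template namespace, and $context gets
# a generated requestId unless one is provided.
rendered = render_velocity_template(
    "Hello $name, your request id is $context.requestId",
    context={},
    variables={"name": "alice"},
)
print(rendered)  # e.g. "Hello alice, your request id is d81d1a0f"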
def mountable_tmp_file():
    f = os.path.join(config.dirs.tmp, short_uid())
    TMP_FILES.append(f)
    return f
def start_kcl_client_process(
    stream_name,
    listener_script,
    log_file=None,
    env=None,
    configs=None,
    endpoint_url=None,
    ddb_lease_table_suffix=None,
    env_vars=None,
    region_name=None,
    kcl_log_level=DEFAULT_KCL_LOG_LEVEL,
    log_subscribers=None,
):
    if configs is None:
        configs = {}
    if env_vars is None:
        env_vars = {}
    if log_subscribers is None:
        log_subscribers = []
    env = aws_stack.get_environment(env)
    # make sure to convert stream ARN to stream name
    stream_name = aws_stack.kinesis_stream_name(stream_name)

    # decide which credentials provider to use
    credentialsProvider = None
    if ("AWS_ASSUME_ROLE_ARN" in os.environ or "AWS_ASSUME_ROLE_ARN" in env_vars) and (
        "AWS_ASSUME_ROLE_SESSION_NAME" in os.environ or "AWS_ASSUME_ROLE_SESSION_NAME" in env_vars
    ):
        # use special credentials provider that can assume IAM roles and handle temporary STS auth tokens
        credentialsProvider = "cloud.localstack.DefaultSTSAssumeRoleSessionCredentialsProvider"
        # pass through env variables to child process
        for var_name in [
            "AWS_ASSUME_ROLE_ARN",
            "AWS_ASSUME_ROLE_SESSION_NAME",
            "AWS_ACCESS_KEY_ID",
            "AWS_SECRET_ACCESS_KEY",
            "AWS_SESSION_TOKEN",
        ]:
            if var_name in os.environ and var_name not in env_vars:
                env_vars[var_name] = os.environ[var_name]
    if aws_stack.is_local_env(env):
        # need to disable CBOR protocol, enforce use of plain JSON,
        # see https://github.com/mhart/kinesalite/issues/31
        env_vars["AWS_CBOR_DISABLE"] = "true"
    if kcl_log_level or (len(log_subscribers) > 0):
        if not log_file:
            log_file = LOG_FILE_PATTERN.replace("*", short_uid())
            TMP_FILES.append(log_file)
        run("touch %s" % log_file)
        # start log output reader thread which will read the KCL log
        # file and print each line to stdout of this process...
        reader_thread = OutputReaderThread(
            {
                "file": log_file,
                "level": kcl_log_level,
                "log_prefix": "KCL",
                "log_subscribers": log_subscribers,
            }
        )
        reader_thread.start()

    # construct stream info
    stream_info = get_stream_info(
        stream_name,
        log_file,
        env=env,
        endpoint_url=endpoint_url,
        ddb_lease_table_suffix=ddb_lease_table_suffix,
        env_vars=env_vars,
    )
    props_file = stream_info["properties_file"]
    # set kcl config options
    kwargs = {"metricsLevel": "NONE", "initialPositionInStream": "LATEST"}
    # set parameters for local connection
    if aws_stack.is_local_env(env):
        kwargs["kinesisEndpoint"] = f"{LOCALHOST}:{config.service_port('kinesis')}"
        kwargs["dynamodbEndpoint"] = f"{LOCALHOST}:{config.service_port('dynamodb')}"
        kwargs["kinesisProtocol"] = config.get_protocol()
        kwargs["dynamodbProtocol"] = config.get_protocol()
        kwargs["disableCertChecking"] = "true"
    kwargs.update(configs)
    # create config file
    kclipy_helper.create_config_file(
        config_file=props_file,
        executableName=listener_script,
        streamName=stream_name,
        applicationName=stream_info["app_name"],
        credentialsProvider=credentialsProvider,
        region_name=region_name,
        **kwargs,
    )
    TMP_FILES.append(props_file)
    # start stream consumer
    stream = KinesisStream(id=stream_name, params=stream_info)
    thread_consumer = KinesisProcessorThread.start_consumer(stream)
    TMP_THREADS.append(thread_consumer)
    return thread_consumer
def listen_to_kinesis(
    stream_name,
    listener_func=None,
    processor_script=None,
    events_file=None,
    endpoint_url=None,
    log_file=None,
    configs=None,
    env=None,
    ddb_lease_table_suffix=None,
    env_vars=None,
    kcl_log_level=DEFAULT_KCL_LOG_LEVEL,
    log_subscribers=None,
    wait_until_started=False,
    fh_d_stream=None,
    region_name=None,
):
    """
    High-level function that subscribes to a Kinesis stream and passes received
    events to a listener function. A KCL client process is automatically started
    in the background.
    """
    if configs is None:
        configs = {}
    if env_vars is None:
        env_vars = {}
    if log_subscribers is None:
        log_subscribers = []
    env = aws_stack.get_environment(env)
    if not events_file:
        events_file = EVENTS_FILE_PATTERN.replace("*", short_uid())
        TMP_FILES.append(events_file)
    if not processor_script:
        processor_script = generate_processor_script(events_file, log_file=log_file)

    rm_rf(events_file)
    # start event reader thread (this process)
    ready_mutex = threading.Semaphore(0)
    thread = EventFileReaderThread(
        events_file, listener_func, ready_mutex=ready_mutex, fh_d_stream=fh_d_stream
    )
    thread.start()
    # Wait until the event reader thread is ready (to avoid 'Connection refused' error on the UNIX socket)
    ready_mutex.acquire()
    # start KCL client (background process)
    if processor_script[-4:] == ".pyc":
        processor_script = processor_script[0:-1]
    # add log listener that notifies when KCL is started
    if wait_until_started:
        listener = KclStartedLogListener()
        log_subscribers.append(listener)

    process = start_kcl_client_process(
        stream_name,
        processor_script,
        endpoint_url=endpoint_url,
        log_file=log_file,
        configs=configs,
        env=env,
        ddb_lease_table_suffix=ddb_lease_table_suffix,
        env_vars=env_vars,
        kcl_log_level=kcl_log_level,
        log_subscribers=log_subscribers,
        region_name=region_name,
    )

    if wait_until_started:
        # Wait at most 90 seconds for initialization. Note that creating the DDB table can take quite a bit
        try:
            listener.sync_init.get(block=True, timeout=90)
        except Exception:
            raise Exception("Timeout when waiting for KCL initialization.")
        # wait at most 30 seconds for shard lease notification
        try:
            listener.sync_take_shard.get(block=True, timeout=30)
        except Exception:
            # this merely means that there is no shard available to take. Do nothing.
            pass

    return process
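# Usage sketch for listen_to_kinesis (hypothetical stream name; the exact
# listener signature depends on EventFileReaderThread, which hands over the
# record payloads read from the KCL events socket).
def _print_records(records, *args, **kwargs):
    for record in records:
        print("received:", record)

consumer = listen_to_kinesis(
    "my-stream",
    listener_func=_print_records,
    wait_until_started=True,  # block until KCL has initialized and taken a shard lease
)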
def generate_processor_script(events_file, log_file=None):
    script_file = os.path.join(tempfile.gettempdir(), "kclipy.%s.processor.py" % short_uid())
    if log_file:
        log_file = "'%s'" % log_file
    else:
        log_file = "None"
    content = """#!/usr/bin/env python

import os, sys, glob, json, socket, time, logging, subprocess, tempfile
logging.basicConfig(level=logging.INFO)
for path in glob.glob('%s/lib/python*/site-packages'):
    sys.path.insert(0, path)
sys.path.insert(0, '%s')
from localstack.config import DEFAULT_ENCODING
from localstack.utils.kinesis import kinesis_connector
from localstack.utils.time import timestamp
events_file = '%s'
log_file = %s
error_log = os.path.join(tempfile.gettempdir(), 'kclipy.error.log')
if __name__ == '__main__':
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    num_tries = 3
    sleep_time = 2
    error = None
    for i in range(0, num_tries):
        try:
            sock.connect(events_file)
            error = None
            break
        except Exception as e:
            error = e
            if i < num_tries:
                msg = '%%s: Unable to connect to UNIX socket. Retrying.' %% timestamp()
                subprocess.check_output('echo "%%s" >> %%s' %% (msg, error_log), shell=True)
                time.sleep(sleep_time)
    if error:
        print("WARN: Unable to connect to UNIX socket after retrying: %%s" %% error)
        raise error

    def receive_msg(records, checkpointer, shard_id):
        try:
            # records is a list of amazon_kclpy.messages.Record objects -> convert to JSON
            records_dicts = [j._json_dict for j in records]
            message_to_send = {'shard_id': shard_id, 'records': records_dicts}
            string_to_send = '%%s\\n' %% json.dumps(message_to_send)
            bytes_to_send = string_to_send.encode(DEFAULT_ENCODING)
            sock.send(bytes_to_send)
        except Exception as e:
            msg = "WARN: Unable to forward event: %%s" %% e
            print(msg)
            subprocess.check_output('echo "%%s" >> %%s' %% (msg, error_log), shell=True)
    kinesis_connector.KinesisProcessor.run_processor(log_file=log_file, processor_func=receive_msg)
""" % (
        LOCALSTACK_VENV_FOLDER,
        LOCALSTACK_ROOT_FOLDER,
        events_file,
        log_file,
    )
    save_file(script_file, content)
    chmod_r(script_file, 0o755)
    TMP_FILES.append(script_file)
    return script_file