Example #1
def test_list_and_delete_apis(apigateway_client):
    api_name1 = short_uid()
    api_name2 = short_uid()

    response = apigateway_client.create_rest_api(name=api_name1,
                                                 description="this is my api")
    api_id = response["id"]
    apigateway_client.create_rest_api(name=api_name2,
                                      description="this is my api2")

    response = apigateway_client.get_rest_apis()
    items = [
        item for item in response["items"]
        if item["name"] in [api_name1, api_name2]
    ]
    assert len(items) == 2

    apigateway_client.delete_rest_api(restApiId=api_id)

    response = apigateway_client.get_rest_apis()
    items = [
        item for item in response["items"]
        if item["name"] in [api_name1, api_name2]
    ]
    assert len(items) == 1
Example #2
    def associate_vpc_with_hosted_zone(
        self,
        context: RequestContext,
        hosted_zone_id: ResourceId,
        vpc: VPC,
        comment: AssociateVPCComment = None,
    ) -> AssociateVPCWithHostedZoneResponse:
        region_details = Route53Backend.get()
        # TODO: handle NoSuchHostedZone and ConflictingDomainExist
        zone_details = region_details.vpc_hosted_zone_associations.get(
            hosted_zone_id) or []
        hosted_zone_association = HostedZoneAssociation(
            hosted_zone_id=hosted_zone_id,
            id=short_uid(),
            vpc=vpc,
            status=ChangeStatus.INSYNC,
            submitted_at=datetime.now(),
        )
        zone_details.append(hosted_zone_association)
        vpc_id = vpc.get("VPCId")
        # update VPC info in hosted zone moto object - fixes required after https://github.com/spulec/moto/pull/4786
        hosted_zone = route53_backend.zones.get(hosted_zone_id)
        if not getattr(hosted_zone, "vpcid", None):
            hosted_zone.vpcid = vpc_id
        if not getattr(hosted_zone, "vpcregion", None):
            hosted_zone.vpcregion = aws_stack.get_region()

        region_details.vpc_hosted_zone_associations[
            hosted_zone_id] = zone_details
        return AssociateVPCWithHostedZoneResponse(
            ChangeInfo=ChangeInfo(Id=short_uid(),
                                  Status=ChangeStatus.INSYNC,
                                  SubmittedAt=datetime.now()))
Example #3
    def test_object_created_and_object_removed(
        self,
        s3_client,
        sqs_client,
        s3_create_bucket,
        sqs_create_queue,
        s3_create_sqs_bucket_notification,
        snapshot,
    ):
        snapshot.add_transformer(snapshot.transform.sqs_api())
        snapshot.add_transformer(snapshot.transform.s3_api())
        snapshot.add_transformer(
            snapshot.transform.jsonpath("$..s3.object.key", "object-key"))

        # setup fixture
        bucket_name = s3_create_bucket()
        queue_url = sqs_create_queue()
        s3_create_sqs_bucket_notification(
            bucket_name, queue_url,
            ["s3:ObjectCreated:*", "s3:ObjectRemoved:*"])

        src_key = "src-dest-%s" % short_uid()
        dest_key = "key-dest-%s" % short_uid()

        # event0 = PutObject
        s3_client.put_object(Bucket=bucket_name, Key=src_key, Body="something")
        # event1 = CopyObject
        s3_client.copy_object(
            Bucket=bucket_name,
            CopySource={
                "Bucket": bucket_name,
                "Key": src_key
            },
            Key=dest_key,
        )
        # event2 = DeleteObject
        s3_client.delete_object(Bucket=bucket_name, Key=src_key)

        # collect events
        events = sqs_collect_s3_events(sqs_client, queue_url, 3)
        assert len(events) == 3, f"unexpected number of events in {events}"

        # ordering is not guaranteed - sort so we can rely on a deterministic order
        events.sort(key=lambda x: x["eventName"])

        snapshot.match("receive_messages", {"messages": events})

        assert events[1]["eventName"] == "ObjectCreated:Put"
        assert events[1]["s3"]["bucket"]["name"] == bucket_name
        assert events[1]["s3"]["object"]["key"] == src_key

        assert events[0]["eventName"] == "ObjectCreated:Copy"
        assert events[0]["s3"]["bucket"]["name"] == bucket_name
        assert events[0]["s3"]["object"]["key"] == dest_key

        assert events[2]["eventName"] == "ObjectRemoved:Delete"
        assert events[2]["s3"]["bucket"]["name"] == bucket_name
        assert events[2]["s3"]["object"]["key"] == src_key
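A note on the assertion indices above: the sort by eventName is plain lexicographic string ordering, so "ObjectCreated:Copy" lands at index 0, "ObjectCreated:Put" at index 1, and "ObjectRemoved:Delete" at index 2. A minimal stdlib-only sketch:

# Sketch: lexicographic sort order of the three event names asserted above.
names = ["ObjectCreated:Put", "ObjectRemoved:Delete", "ObjectCreated:Copy"]
names.sort()
assert names == ["ObjectCreated:Copy", "ObjectCreated:Put", "ObjectRemoved:Delete"]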
Example #4
def requests_error_response_xml_signature_calculation(
    message,
    string_to_sign=None,
    signature=None,
    expires=None,
    code=400,
    code_string="AccessDenied",
    aws_access_token="temp",
):
    response = RequestsResponse()
    response_template = """<?xml version="1.0" encoding="UTF-8"?>
        <Error>
            <Code>{code_string}</Code>
            <Message>{message}</Message>
            <RequestId>{req_id}</RequestId>
            <HostId>{host_id}</HostId>
        </Error>""".format(
        message=message,
        code_string=code_string,
        req_id=short_uid(),
        host_id=short_uid(),
    )

    parsed_response = xmltodict.parse(response_template)
    response.status_code = code

    if (signature and string_to_sign) or code_string == "SignatureDoesNotMatch":
        bytes_signature = binascii.hexlify(bytes(signature, encoding="utf-8"))
        parsed_response["Error"]["Code"] = code_string
        parsed_response["Error"]["AWSAccessKeyId"] = aws_access_token
        parsed_response["Error"]["StringToSign"] = string_to_sign
        parsed_response["Error"]["SignatureProvided"] = signature
        parsed_response["Error"]["StringToSignBytes"] = "{}".format(
            bytes_signature.decode("utf-8"))
        set_response_content(response, xmltodict.unparse(parsed_response))

    if expires and code_string == "AccessDenied":
        server_time = datetime.datetime.utcnow().isoformat()[:-4]
        expires_isoformat = datetime.datetime.fromtimestamp(
            int(expires)).isoformat()[:-4]
        parsed_response["Error"]["Code"] = code_string
        parsed_response["Error"]["Expires"] = "{}Z".format(expires_isoformat)
        parsed_response["Error"]["ServerTime"] = "{}Z".format(server_time)
        set_response_content(response, xmltodict.unparse(parsed_response))

    if not signature and not expires and code_string == "AccessDenied":
        set_response_content(response, xmltodict.unparse(parsed_response))

    if response._content:
        return response
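A note on the condition above: and binds tighter than or in Python, so the parentheses added here only make explicit what the expression always meant. A caveat follows from that: when only code_string matches, signature may still be None, and bytes(signature, encoding="utf-8") would raise TypeError. A minimal sketch of the precedence (values illustrative):

# Sketch: "a and b or c" parses as "(a and b) or c".
signature, string_to_sign = None, None
code_string = "SignatureDoesNotMatch"
taken = signature and string_to_sign or code_string == "SignatureDoesNotMatch"
assert taken is True  # the "or" arm fires even though signature is None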
Example #5
def _generate_machine_id() -> str:
    if config.is_in_docker:
        return short_uid()

    # this can potentially be useful when generated on the host using the CLI and then mounted into the container via
    # machine.json
    try:
        if os.path.exists("/etc/machine-id"):
            with open("/etc/machine-id") as fd:
                return md5(str(fd.read()))[:8]
    except Exception:
        pass

    # always fall back to short_uid()
    return short_uid()
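For reference, a minimal stdlib-only sketch of the /etc/machine-id branch, assuming the md5 helper used above returns a hex digest string (an assumption; hashlib stands in for it below):

# Sketch: derive a stable 8-character machine id from /etc/machine-id.
import hashlib

def machine_id_from_file(path="/etc/machine-id"):
    try:
        with open(path) as fd:
            return hashlib.md5(fd.read().encode()).hexdigest()[:8]
    except OSError:
        return None  # caller would fall back to short_uid()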
Example #6
 def __init__(
     self,
     method,
     path,
     data,
     headers,
     api_id=None,
     stage=None,
     context=None,
     auth_info=None,
 ):
     self.method = method
     self.path = path
     self.data = data
     self.headers = headers
     self.context = {"requestId": short_uid()} if context is None else context
     self.auth_info = {} if auth_info is None else auth_info
     self.apigw_version = None
     self.api_id = api_id
     self.stage = stage
     self.region_name = None
     self.integration = None
     self.resource = None
     self.resource_path = None
     self.path_with_query_string = None
     self.response_templates = {}
     self.stage_variables = {}
     self.path_params = {}
     self.ws_route = None
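The context=None default above, with the dict created inside the body, avoids Python's mutable-default-argument pitfall: a default of context={} would be evaluated once and shared across all calls. A minimal sketch:

# Sketch: a mutable default argument is evaluated once and shared.
def bad(ctx={}):
    return ctx

assert bad() is bad()  # same dict object on every call

def good(ctx=None):
    return {} if ctx is None else ctx

assert good() is not good()  # a fresh dict per call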
Example #7
 def __init__(self, metadata=None, template=None):
     if template is None:
         template = {}
     self.metadata = metadata or {}
     self.template = template or {}
     self._template_raw = clone_safe(self.template)
     self.template_original = clone_safe(self.template)
     # initialize resources
     for resource_id, resource in self.template_resources.items():
         resource["LogicalResourceId"] = self.template_original["Resources"][resource_id][
             "LogicalResourceId"
         ] = (resource.get("LogicalResourceId") or resource_id)
     # initialize stack template attributes
     stack_id = self.metadata.get("StackId") or aws_stack.cloudformation_stack_arn(
         self.stack_name, short_uid()
     )
     self.template["StackId"] = self.metadata["StackId"] = stack_id
     self.template["Parameters"] = self.template.get("Parameters") or {}
     self.template["Outputs"] = self.template.get("Outputs") or {}
     self.template["Conditions"] = self.template.get("Conditions") or {}
     # initialize metadata
     self.metadata["Parameters"] = self.metadata.get("Parameters") or []
     self.metadata["StackStatus"] = "CREATE_IN_PROGRESS"
     self.metadata["CreationTime"] = self.metadata.get("CreationTime") or timestamp_millis()
     # maps resource id to resource state
     self._resource_states = {}
     # list of stack events
     self.events = []
     # list of stack change sets
     self.change_sets = []
Example #8
 def get_parameters_for_import(
     self,
     context: RequestContext,
     key_id: KeyIdType,
     wrapping_algorithm: AlgorithmSpec,
     wrapping_key_spec: WrappingKeySpec,
 ) -> GetParametersForImportResponse:
     key = _generate_data_key_pair({"KeySpec": wrapping_key_spec},
                                   create_cipher=False,
                                   add_to_keys=False)
     import_token = short_uid()
     import_state = KeyImportState(
         key_id=key_id,
         import_token=import_token,
         private_key=key["PrivateKeyPlaintext"],
         public_key=key["PublicKey"],
         wrapping_algo=wrapping_algorithm,
         key_obj=key["_key_"],
     )
     KMSBackend.get().imports[import_token] = import_state
     expiry_date = datetime.datetime.now() + datetime.timedelta(days=100)
     return GetParametersForImportResponse(
         KeyId=key_id,
         ImportToken=to_bytes(import_state.import_token),
         PublicKey=import_state.public_key,
         ParametersValidTo=expiry_date,
     )
Example #9
    def create_stack_instances(
        self,
        context: RequestContext,
        request: CreateStackInstancesInput,
    ) -> CreateStackInstancesOutput:
        state = CloudFormationRegion.get()

        set_name = request.get("StackSetName")
        stack_set = [sset for sset in state.stack_sets.values() if sset.stack_set_name == set_name]

        if not stack_set:
            return not_found_error(f'Stack set named "{set_name}" does not exist')

        stack_set = stack_set[0]
        op_id = request.get("OperationId") or short_uid()
        sset_meta = stack_set.metadata
        accounts = request["Accounts"]
        regions = request["Regions"]

        stacks_to_await = []
        for account in accounts:
            for region in regions:
                # deploy new stack
                LOG.debug('Deploying instance for stack set "%s" in region "%s"', set_name, region)
                cf_client = aws_stack.connect_to_service("cloudformation", region_name=region)
                kwargs = select_attributes(sset_meta, ["TemplateBody"]) or select_attributes(
                    sset_meta, ["TemplateURL"]
                )
                stack_name = f"sset-{set_name}-{account}"
                result = cf_client.create_stack(StackName=stack_name, **kwargs)
                stacks_to_await.append((stack_name, region))
                # store stack instance
                instance = {
                    "StackSetId": sset_meta["StackSetId"],
                    "OperationId": op_id,
                    "Account": account,
                    "Region": region,
                    "StackId": result["StackId"],
                    "Status": "CURRENT",
                    "StackInstanceStatus": {"DetailedStatus": "SUCCEEDED"},
                }
                instance = StackInstance(instance)
                stack_set.stack_instances.append(instance)

        # wait for completion of the stacks
        for stack in stacks_to_await:
            aws_stack.await_stack_completion(stack[0], region_name=stack[1])

        # record operation
        operation = {
            "OperationId": op_id,
            "StackSetId": stack_set.metadata["StackSetId"],
            "Action": "CREATE",
            "Status": "SUCCEEDED",
        }
        stack_set.operations[op_id] = operation

        return CreateStackInstancesOutput(OperationId=op_id)
Example #10
    def test_object_created_copy(
        self,
        s3_client,
        sqs_client,
        s3_create_bucket,
        sqs_create_queue,
        s3_create_sqs_bucket_notification,
        snapshot,
    ):
        snapshot.add_transformer(snapshot.transform.sqs_api())
        snapshot.add_transformer(snapshot.transform.s3_api())
        snapshot.add_transformer(
            snapshot.transform.jsonpath("$..s3.object.key", "object-key"))

        # setup fixture
        bucket_name = s3_create_bucket()
        queue_url = sqs_create_queue()
        s3_create_sqs_bucket_notification(bucket_name, queue_url,
                                          ["s3:ObjectCreated:Copy"])

        src_key = "src-dest-%s" % short_uid()
        dest_key = "key-dest-%s" % short_uid()

        s3_client.put_object(Bucket=bucket_name, Key=src_key, Body="something")

        assert not sqs_collect_s3_events(
            sqs_client, queue_url, 0,
            timeout=1), "unexpected event triggered for put_object"

        s3_client.copy_object(
            Bucket=bucket_name,
            CopySource={
                "Bucket": bucket_name,
                "Key": src_key
            },
            Key=dest_key,
        )

        events = sqs_collect_s3_events(sqs_client, queue_url, 1)
        assert len(events) == 1, f"unexpected number of events in {events}"
        snapshot.match("receive_messages", {"messages": events})
        assert events[0]["eventSource"] == "aws:s3"
        assert events[0]["eventName"] == "ObjectCreated:Copy"
        assert events[0]["s3"]["bucket"]["name"] == bucket_name
        assert events[0]["s3"]["object"]["key"] == dest_key
Example #11
    def test_object_tagging_delete_event(
        self,
        s3_client,
        sqs_client,
        s3_create_bucket,
        sqs_create_queue,
        s3_create_sqs_bucket_notification,
        snapshot,
    ):
        snapshot.add_transformer(snapshot.transform.sqs_api())
        snapshot.add_transformer(snapshot.transform.s3_api())
        snapshot.add_transformer(
            snapshot.transform.jsonpath("$..s3.object.key", "object-key"))

        # setup fixture
        bucket_name = s3_create_bucket()
        queue_url = sqs_create_queue()
        s3_create_sqs_bucket_notification(bucket_name, queue_url,
                                          ["s3:ObjectTagging:Delete"])

        dest_key = "key-dest-%s" % short_uid()

        s3_client.put_object(Bucket=bucket_name,
                             Key=dest_key,
                             Body="FooBarBlitz")

        assert not sqs_collect_s3_events(
            sqs_client, queue_url, 0,
            timeout=1), "unexpected event triggered for put_object"

        s3_client.put_object_tagging(
            Bucket=bucket_name,
            Key=dest_key,
            Tagging={
                "TagSet": [
                    {
                        "Key": "swallow_type",
                        "Value": "african"
                    },
                ]
            },
        )

        s3_client.delete_object_tagging(
            Bucket=bucket_name,
            Key=dest_key,
        )

        events = sqs_collect_s3_events(sqs_client, queue_url, 1)
        assert len(events) == 1, f"unexpected number of events in {events}"
        snapshot.match("receive_messages", {"messages": events})

        assert events[0]["eventSource"] == "aws:s3"
        assert events[0]["eventName"] == "ObjectTagging:Delete"
        assert events[0]["s3"]["bucket"]["name"] == bucket_name
        assert events[0]["s3"]["object"]["key"] == dest_key
Example #12
    def create_stack_set(
        self, context: RequestContext, request: CreateStackSetInput
    ) -> CreateStackSetOutput:
        state = CloudFormationRegion.get()
        stack_set = StackSet(request)
        stack_set_id = short_uid()
        stack_set.metadata["StackSetId"] = stack_set_id
        state.stack_sets[stack_set_id] = stack_set

        return CreateStackSetOutput(StackSetId=stack_set_id)
Example #13
 def create_vpc_link(
     self,
     context: RequestContext,
     name: String,
     target_arns: ListOfString,
     description: String = None,
     tags: MapOfStringToString = None,
 ) -> VpcLink:
     region_details = APIGatewayRegion.get()
     link_id = short_uid()
     entry = {"id": link_id, "status": "AVAILABLE"}
     region_details.vpc_links[link_id] = entry
     result = to_vpc_link_response_json(entry)
     return VpcLink(**result)
Example #14
def events_handler_put_events(self):
    entries = self._get_param("Entries")

    # keep track of events for local integration testing
    if config.is_local_test_mode():
        TEST_EVENTS_CACHE.extend(entries)

    events = list(map(lambda event: {"event": event, "uuid": str(long_uid())}, entries))

    _dump_events_to_files(events)
    event_rules = self.events_backend.rules

    for event_envelope in events:
        event = event_envelope["event"]
        event_bus = event.get("EventBusName") or DEFAULT_EVENT_BUS_NAME

        matching_rules = [r for r in event_rules.values() if r.event_bus_name == event_bus]
        if not matching_rules:
            continue

        formatted_event = {
            "version": "0",
            "id": event_envelope["uuid"],
            "detail-type": event.get("DetailType"),
            "source": event.get("Source"),
            "account": TEST_AWS_ACCOUNT_ID,
            "time": datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ"),
            "region": self.region,
            "resources": event.get("Resources", []),
            "detail": json.loads(event.get("Detail", "{}")),
        }

        targets = []
        for rule in matching_rules:
            if filter_event_based_on_event_format(self, rule.name, formatted_event):
                targets.extend(self.events_backend.list_targets_by_rule(rule.name)["Targets"])

        # process event
        process_events(formatted_event, targets)

    content = {
        "FailedEntryCount": 0,  # TODO: dynamically set proper value when refactoring
        "Entries": list(map(lambda event: {"EventId": event["uuid"]}, events)),
    }

    self.response_headers.update(
        {"Content-Type": APPLICATION_AMZ_JSON_1_1, "x-amzn-RequestId": short_uid()}
    )

    return json.dumps(content), self.response_headers
Example #15
    def create_authorizer(
        self, context: RequestContext, request: CreateAuthorizerRequest
    ) -> Authorizer:
        region_details = APIGatewayRegion.get()

        api_id = request["restApiId"]
        authorizer_id = short_uid()[:6]  # length 6 to make TF tests pass
        result = deepcopy(request)

        result["id"] = authorizer_id
        result = normalize_authorizer(result)
        region_details.authorizers.setdefault(api_id, []).append(result)

        result = to_authorizer_response_json(api_id, result)
        return Authorizer(**result)
Example #16
    def __init__(self, params=None, template=None):
        if template is None:
            template = {}
        if params is None:
            params = {}
        super(StackChangeSet, self).__init__(params, template)

        name = self.metadata["ChangeSetName"]
        if not self.metadata.get("ChangeSetId"):
            self.metadata["ChangeSetId"] = aws_stack.cf_change_set_arn(
                name, change_set_id=short_uid()
            )

        stack = self.stack = find_stack(self.metadata["StackName"])
        self.metadata["StackId"] = stack.stack_id
        self.metadata["Status"] = "CREATE_PENDING"
Example #17
 def generate_client_certificate(
     self, context: RequestContext, description: String = None, tags: MapOfStringToString = None
 ) -> ClientCertificate:
     region_details = APIGatewayRegion.get()
     cert_id = short_uid()
     creation_time = now_utc()
     entry = {
         "description": description,
         "tags": tags,
         "clientCertificateId": cert_id,
         "createdDate": creation_time,
         "expirationDate": creation_time + 60 * 60 * 24 * 30,  # assume 30 days validity
         "pemEncodedCertificate": "testcert-123",  # TODO return proper certificate!
     }
     region_details.client_certificates[cert_id] = entry
     result = to_client_cert_response_json(entry)
     return ClientCertificate(**result)
Example #18
def get_stream_info(
    stream_name,
    log_file=None,
    shards=None,
    env=None,
    endpoint_url=None,
    ddb_lease_table_suffix=None,
    env_vars=None,
):
    if env_vars is None:
        env_vars = {}
    if not ddb_lease_table_suffix:
        ddb_lease_table_suffix = DEFAULT_DDB_LEASE_TABLE_SUFFIX
    # construct stream info
    env = aws_stack.get_environment(env)
    props_file = os.path.join(tempfile.gettempdir(),
                              "kclipy.%s.properties" % short_uid())
    # make sure to convert stream ARN to stream name
    stream_name = aws_stack.kinesis_stream_name(stream_name)
    app_name = "%s%s" % (stream_name, ddb_lease_table_suffix)
    stream_info = {
        "name": stream_name,
        "region": aws_stack.get_region(),
        "shards": shards,
        "properties_file": props_file,
        "log_file": log_file,
        "app_name": app_name,
        "env_vars": env_vars,
    }
    # set local connection
    if aws_stack.is_local_env(env):
        stream_info["conn_kwargs"] = {
            "host": LOCALHOST,
            "port": config.service_port("kinesis"),
            "is_secure": bool(config.USE_SSL),
        }
    if endpoint_url:
        if "conn_kwargs" not in stream_info:
            stream_info["conn_kwargs"] = {}
        url = urlparse(endpoint_url)
        stream_info["conn_kwargs"]["host"] = url.hostname
        stream_info["conn_kwargs"]["port"] = url.port
        stream_info["conn_kwargs"]["is_secure"] = url.scheme == "https"
    return stream_info
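The conn_kwargs block above leans on urllib.parse.urlparse to split the endpoint URL into host, port, and scheme; a minimal sketch with an illustrative URL:

# Sketch: urlparse fields used for conn_kwargs above (URL is illustrative).
from urllib.parse import urlparse

url = urlparse("https://kinesis.example.com:4566")
assert url.hostname == "kinesis.example.com"
assert url.port == 4566
assert url.scheme == "https"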
Example #19
 def update_stack_set(
     self, context: RequestContext, request: UpdateStackSetInput
 ) -> UpdateStackSetOutput:
     state = CloudFormationRegion.get()
     set_name = request.get("StackSetName")
     stack_set = [sset for sset in state.stack_sets.values() if sset.stack_set_name == set_name]
     if not stack_set:
         return not_found_error(f'Stack set named "{set_name}" does not exist')
     stack_set = stack_set[0]
     stack_set.metadata.update(request)
     op_id = request.get("OperationId") or short_uid()
     operation = {
         "OperationId": op_id,
         "StackSetId": stack_set.metadata["StackSetId"],
         "Action": "UPDATE",
         "Status": "SUCCEEDED",
     }
     stack_set.operations[op_id] = operation
     return UpdateStackSetOutput(OperationId=op_id)
Example #20
def requests_error_response_xml(
    message: str,
    code: Optional[int] = 400,
    code_string: Optional[str] = "InvalidParameter",
    service: Optional[str] = None,
    xmlns: Optional[str] = None,
):
    response = RequestsResponse()
    xmlns = xmlns or "http://%s.amazonaws.com/doc/2010-03-31/" % service
    response._content = """<ErrorResponse xmlns="{xmlns}"><Error>
        <Type>Sender</Type>
        <Code>{code_string}</Code>
        <Message>{message}</Message>
        </Error><RequestId>{req_id}</RequestId>
        </ErrorResponse>""".format(xmlns=xmlns,
                                   message=message,
                                   code_string=code_string,
                                   req_id=short_uid())
    response.status_code = code
    return response
Example #21
    def create_documentation_part(
        self,
        context: RequestContext,
        rest_api_id: String,
        location: DocumentationPartLocation,
        properties: String,
    ) -> DocumentationPart:
        region_details = APIGatewayRegion.get()

        entity_id = short_uid()[:6]  # length 6 for AWS parity / Terraform compatibility
        entry = {
            "id": entity_id,
            "restApiId": rest_api_id,
            "location": location,
            "properties": properties,
        }

        region_details.documentation_parts.setdefault(rest_api_id, []).append(entry)

        result = to_documentation_part_response_json(rest_api_id, entry)
        return DocumentationPart(**result)
Example #22
    def create_request_validator(
        self,
        context: RequestContext,
        rest_api_id: String,
        name: String = None,
        validate_request_body: Boolean = None,
        validate_request_parameters: Boolean = None,
    ) -> RequestValidator:
        region_details = APIGatewayRegion.get()

        # length 6 for AWS parity and TF compatibility
        validator_id = short_uid()[:6]

        entry = {
            "id": validator_id,
            "name": name,
            "restApiId": rest_api_id,
            "validateRequestBody": validate_request_body,
            "validateRequestParameters": validate_request_parameters,
        }
        region_details.validators.setdefault(rest_api_id, []).append(entry)

        return RequestValidator(**entry)
Example #23
def render_velocity_template(template, context, variables=None, as_json=False):
    if variables is None:
        variables = {}

    if not template:
        return template

    # fix "#set" commands
    template = re.sub(r"(^|\n)#\s+set(.*)", r"\1#set\2", template,
                      flags=re.MULTILINE)

    # enable syntax like "test#${foo.bar}"
    empty_placeholder = " __pLaCe-HoLdEr__ "
    template = re.sub(
        r"([^\s]+)#\$({)?(.*)",
        r"\1#%s$\2\3" % empty_placeholder,
        template,
        flags=re.MULTILINE,
    )

    # add extensions for common string functions below

    class ExtendedString(str):
        def trim(self, *args, **kwargs):
            return ExtendedString(self.strip(*args, **kwargs))

        def toLowerCase(self, *args, **kwargs):
            return ExtendedString(self.lower(*args, **kwargs))

        def toUpperCase(self, *args, **kwargs):
            return ExtendedString(self.upper(*args, **kwargs))

    def apply(obj, **kwargs):
        if isinstance(obj, dict):
            for k, v in obj.items():
                if isinstance(v, str):
                    obj[k] = ExtendedString(v)
        return obj

    # loop through the variables and enable certain additional util functions (e.g., string utils)
    variables = variables or {}
    recurse_object(variables, apply)

    # prepare and render template
    context_var = variables.get("context") or {}
    context_var.setdefault("requestId", short_uid())
    t = airspeed.Template(template)
    var_map = {
        "input": VelocityInput(context),
        "util": VelocityUtil(),
        "context": context_var,
    }
    var_map.update(variables or {})
    replaced = t.merge(var_map)

    # revert temporary changes from the fixes above
    replaced = replaced.replace(empty_placeholder, "")

    if as_json:
        replaced = json.loads(replaced)
    return replaced
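As fixed above, the fourth positional argument of re.sub is count, not flags, so re.MULTILINE must be passed by keyword. A minimal sketch of the pitfall:

# Sketch: re.MULTILINE (value 8) passed positionally is silently treated as count.
import re

text = "a\na\na"
assert re.sub("^a", "b", text, re.MULTILINE) == "b\na\na"  # flag ignored, acts as count=8
assert re.sub("^a", "b", text, flags=re.MULTILINE) == "b\nb\nb"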
Example #24
 def mountable_tmp_file():
     f = os.path.join(config.dirs.tmp, short_uid())
     TMP_FILES.append(f)
     return f
Example #25
def start_kcl_client_process(
    stream_name,
    listener_script,
    log_file=None,
    env=None,
    configs=None,
    endpoint_url=None,
    ddb_lease_table_suffix=None,
    env_vars=None,
    region_name=None,
    kcl_log_level=DEFAULT_KCL_LOG_LEVEL,
    log_subscribers=None,
):
    if configs is None:
        configs = {}
    if env_vars is None:
        env_vars = {}
    if log_subscribers is None:
        log_subscribers = []
    env = aws_stack.get_environment(env)
    # make sure to convert stream ARN to stream name
    stream_name = aws_stack.kinesis_stream_name(stream_name)
    # decide which credentials provider to use
    credentialsProvider = None
    if ("AWS_ASSUME_ROLE_ARN" in os.environ or "AWS_ASSUME_ROLE_ARN" in env_vars) and (
        "AWS_ASSUME_ROLE_SESSION_NAME" in os.environ
        or "AWS_ASSUME_ROLE_SESSION_NAME" in env_vars
    ):
        # use special credentials provider that can assume IAM roles and handle temporary STS auth tokens
        credentialsProvider = "cloud.localstack.DefaultSTSAssumeRoleSessionCredentialsProvider"
        # pass through env variables to child process
        for var_name in [
                "AWS_ASSUME_ROLE_ARN",
                "AWS_ASSUME_ROLE_SESSION_NAME",
                "AWS_ACCESS_KEY_ID",
                "AWS_SECRET_ACCESS_KEY",
                "AWS_SESSION_TOKEN",
        ]:
            if var_name in os.environ and var_name not in env_vars:
                env_vars[var_name] = os.environ[var_name]
    if aws_stack.is_local_env(env):
        # need to disable CBOR protocol, enforce use of plain JSON,
        # see https://github.com/mhart/kinesalite/issues/31
        env_vars["AWS_CBOR_DISABLE"] = "true"
    if kcl_log_level or (len(log_subscribers) > 0):
        if not log_file:
            log_file = LOG_FILE_PATTERN.replace("*", short_uid())
            TMP_FILES.append(log_file)
        run("touch %s" % log_file)
        # start log output reader thread which will read the KCL log
        # file and print each line to stdout of this process...
        reader_thread = OutputReaderThread({
            "file": log_file,
            "level": kcl_log_level,
            "log_prefix": "KCL",
            "log_subscribers": log_subscribers,
        })
        reader_thread.start()

    # construct stream info
    stream_info = get_stream_info(
        stream_name,
        log_file,
        env=env,
        endpoint_url=endpoint_url,
        ddb_lease_table_suffix=ddb_lease_table_suffix,
        env_vars=env_vars,
    )
    props_file = stream_info["properties_file"]
    # set kcl config options
    kwargs = {"metricsLevel": "NONE", "initialPositionInStream": "LATEST"}
    # set parameters for local connection
    if aws_stack.is_local_env(env):
        kwargs["kinesisEndpoint"] = f"{LOCALHOST}:{config.service_port('kinesis')}"
        kwargs["dynamodbEndpoint"] = f"{LOCALHOST}:{config.service_port('dynamodb')}"
        kwargs["kinesisProtocol"] = config.get_protocol()
        kwargs["dynamodbProtocol"] = config.get_protocol()
        kwargs["disableCertChecking"] = "true"
    kwargs.update(configs)
    # create config file
    kclipy_helper.create_config_file(
        config_file=props_file,
        executableName=listener_script,
        streamName=stream_name,
        applicationName=stream_info["app_name"],
        credentialsProvider=credentialsProvider,
        region_name=region_name,
        **kwargs,
    )
    TMP_FILES.append(props_file)
    # start stream consumer
    stream = KinesisStream(id=stream_name, params=stream_info)
    thread_consumer = KinesisProcessorThread.start_consumer(stream)
    TMP_THREADS.append(thread_consumer)
    return thread_consumer
Example #26
def listen_to_kinesis(
    stream_name,
    listener_func=None,
    processor_script=None,
    events_file=None,
    endpoint_url=None,
    log_file=None,
    configs=None,
    env=None,
    ddb_lease_table_suffix=None,
    env_vars=None,
    kcl_log_level=DEFAULT_KCL_LOG_LEVEL,
    log_subscribers=None,
    wait_until_started=False,
    fh_d_stream=None,
    region_name=None,
):
    """
    High-level function that allows subscribing to a Kinesis stream
    and receive events in a listener function. A KCL client process is
    automatically started in the background.
    """
    if configs is None:
        configs = {}
    if env_vars is None:
        env_vars = {}
    if log_subscribers is None:
        log_subscribers = []
    env = aws_stack.get_environment(env)
    if not events_file:
        events_file = EVENTS_FILE_PATTERN.replace("*", short_uid())
        TMP_FILES.append(events_file)
    if not processor_script:
        processor_script = generate_processor_script(events_file,
                                                     log_file=log_file)

    rm_rf(events_file)
    # start event reader thread (this process)
    ready_mutex = threading.Semaphore(0)
    thread = EventFileReaderThread(events_file,
                                   listener_func,
                                   ready_mutex=ready_mutex,
                                   fh_d_stream=fh_d_stream)
    thread.start()
    # Wait until the event reader thread is ready (to avoid 'Connection refused' error on the UNIX socket)
    ready_mutex.acquire()
    # start KCL client (background process)
    if processor_script.endswith(".pyc"):
        processor_script = processor_script[:-1]  # strip trailing "c" to run the .py source
    # add log listener that notifies when KCL is started
    if wait_until_started:
        listener = KclStartedLogListener()
        log_subscribers.append(listener)

    process = start_kcl_client_process(
        stream_name,
        processor_script,
        endpoint_url=endpoint_url,
        log_file=log_file,
        configs=configs,
        env=env,
        ddb_lease_table_suffix=ddb_lease_table_suffix,
        env_vars=env_vars,
        kcl_log_level=kcl_log_level,
        log_subscribers=log_subscribers,
        region_name=region_name,
    )

    if wait_until_started:
        # Wait at most 90 seconds for initialization. Note that creating the DDB table can take quite a while
        try:
            listener.sync_init.get(block=True, timeout=90)
        except Exception:
            raise Exception("Timeout when waiting for KCL initialization.")
        # wait at most 30 seconds for shard lease notification
        try:
            listener.sync_take_shard.get(block=True, timeout=30)
        except Exception:
            # this merely means that there is no shard available to take. Do nothing.
            pass

    return process
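The ready_mutex pattern above uses threading.Semaphore(0) as a one-shot readiness latch: acquire() blocks until the reader thread calls release(). A minimal sketch:

# Sketch: Semaphore(0) as a ready-signal between two threads.
import threading
import time

ready = threading.Semaphore(0)

def worker():
    time.sleep(0.1)  # simulate setup work (e.g., binding a UNIX socket)
    ready.release()  # signal readiness

threading.Thread(target=worker).start()
ready.acquire()  # blocks until the worker has released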
Example #27
def generate_processor_script(events_file, log_file=None):
    script_file = os.path.join(tempfile.gettempdir(),
                               "kclipy.%s.processor.py" % short_uid())
    if log_file:
        log_file = "'%s'" % log_file
    else:
        log_file = "None"
    content = """#!/usr/bin/env python
import os, sys, glob, json, socket, time, logging, subprocess, tempfile
logging.basicConfig(level=logging.INFO)
for path in glob.glob('%s/lib/python*/site-packages'):
    sys.path.insert(0, path)
sys.path.insert(0, '%s')
from localstack.config import DEFAULT_ENCODING
from localstack.utils.kinesis import kinesis_connector
from localstack.utils.time import timestamp
events_file = '%s'
log_file = %s
error_log = os.path.join(tempfile.gettempdir(), 'kclipy.error.log')
if __name__ == '__main__':
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

    num_tries = 3
    sleep_time = 2
    error = None
    for i in range(0, num_tries):
        try:
            sock.connect(events_file)
            error = None
            break
        except Exception as e:
            error = e
            if i < num_tries - 1:  # only log and sleep while a retry is still pending
                msg = '%%s: Unable to connect to UNIX socket. Retrying.' %% timestamp()
                subprocess.check_output('echo "%%s" >> %%s' %% (msg, error_log), shell=True)
                time.sleep(sleep_time)
    if error:
        print("WARN: Unable to connect to UNIX socket after retrying: %%s" %% error)
        raise error

    def receive_msg(records, checkpointer, shard_id):
        try:
            # records is a list of amazon_kclpy.messages.Record objects -> convert to JSON
            records_dicts = [j._json_dict for j in records]
            message_to_send = {'shard_id': shard_id, 'records': records_dicts}
            string_to_send = '%%s\\n' %% json.dumps(message_to_send)
            bytes_to_send = string_to_send.encode(DEFAULT_ENCODING)
            sock.send(bytes_to_send)
        except Exception as e:
            msg = "WARN: Unable to forward event: %%s" %% e
            print(msg)
            subprocess.check_output('echo "%%s" >> %%s' %% (msg, error_log), shell=True)
    kinesis_connector.KinesisProcessor.run_processor(log_file=log_file, processor_func=receive_msg)
    """ % (
        LOCALSTACK_VENV_FOLDER,
        LOCALSTACK_ROOT_FOLDER,
        events_file,
        log_file,
    )
    save_file(script_file, content)
    chmod_r(script_file, 0o755)
    TMP_FILES.append(script_file)
    return script_file
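Since the script body above is assembled with %-formatting, every literal percent sign inside it is doubled (%%) so that it survives the outer substitution. A minimal sketch:

# Sketch: %% in a %-formatted template yields a single literal %.
template = "inner = '%%s' %% value  # outer fills in: %s"
assert template % "done" == "inner = '%s' % value  # outer fills in: done"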