Esempio n. 1
0
def set_api_id_stage_invocation_path(
    invocation_context: ApiInvocationContext, ) -> ApiInvocationContext:
    """Fill in ``api_id``, ``stage`` and ``path_with_query_string`` on the context.

    The values are derived from the first of these that matches the request:

    1. the request path against ``PATH_REGEX_USER_REQUEST``;
    2. the ``HEADER_LOCALSTACK_EDGE_URL`` (or ``Host``) header against
       ``HOST_REGEX_EXECUTE_API``;
    3. the request path against ``PATH_REGEX_TEST_INVOKE_API``. In this case the
       resource path is fetched through the apigateway client, and the stage is
       left as ``None``.

    Nothing is changed when all three fields are already set, or when the request
    is a websocket request.

    :param invocation_context: context to populate (mutated in place)
    :return: the same ``invocation_context`` object
    :raises Exception: if none of the patterns above match the request
    """
    already_resolved = all((
        invocation_context.api_id,
        invocation_context.stage,
        invocation_context.path_with_query_string,
    ))
    if already_resolved or invocation_context.is_websocket_request():
        return invocation_context

    request_path = invocation_context.path
    request_headers = invocation_context.headers

    user_request_match = re.search(PATH_REGEX_USER_REQUEST, request_path)
    host = (request_headers.get(HEADER_LOCALSTACK_EDGE_URL, "")
            or request_headers.get("Host") or "")
    execute_api_host_match = re.search(HOST_REGEX_EXECUTE_API, host)
    test_invoke_match = re.search(PATH_REGEX_TEST_INVOKE_API, request_path)

    if user_request_match:
        api_id, stage, sub_path = user_request_match.group(1, 2, 3)
        relative_path_w_query_params = "/%s" % sub_path
    elif execute_api_host_match:
        api_id = extract_api_id_from_hostname_in_url(host)
        # the first path segment is the stage, the remainder is the resource path
        stage = request_path.strip("/").split("/")[0]
        _, _, remainder = request_path.lstrip("/").partition("/")
        relative_path_w_query_params = "/%s" % remainder
    elif test_invoke_match:
        # TestInvokeApi requests reference a resource ID rather than a path,
        # so the resource path is looked up from the API Gateway backend.
        region = invocation_context.region_name
        api_id = test_invoke_match.group(1)
        resource_id = test_invoke_match.group(2)
        query_string = test_invoke_match.group(4) or ""
        client = aws_stack.connect_to_service(service_name="apigateway",
                                              region_name=region)
        resource = client.get_resource(restApiId=api_id,
                                       resourceId=resource_id)
        relative_path_w_query_params = f"{resource.get('path')}{query_string}"
        stage = None
    else:
        raise Exception(
            f"Unable to extract API Gateway details from request: {request_path} {dict(request_headers)}"
        )

    if api_id:
        # expose the API's region to aws_stack.get_region() for this request thread
        thread_request_context = getattr(THREAD_LOCAL, "request_context", None)
        if thread_request_context is not None:
            thread_request_context.headers[MARKER_APIGW_REQUEST_REGION] = (
                API_REGIONS.get(api_id, ""))

    invocation_context.api_id = api_id
    invocation_context.stage = stage
    invocation_context.path_with_query_string = relative_path_w_query_params
    return invocation_context
Esempio n. 2
0
def invoke_rest_api(invocation_context: ApiInvocationContext):
    """Resolve the method integration for a REST API request and invoke it.

    Steps:

    1. Run the authorizers and look up the target resource.
    2. Validate the request.
    3. Check the API key.
    4. Resolve the method integration.
    5. Store the resolved details on the context and delegate to
       ``invoke_rest_api_integration``.

    :param invocation_context: context with api_id/stage/path already populated
    :return: an error response (404 / 400 / 403), a CORS response for unmatched
        ``OPTIONS`` requests, or the integration response
    """
    invocation_path = invocation_context.path_with_query_string
    raw_path = invocation_context.path or invocation_path
    method = invocation_context.method
    headers = invocation_context.headers

    # run gateway authorizers for this request
    authorize_invocation(invocation_context)

    extracted_path, resource = get_target_resource_details(invocation_context)
    if not resource:
        return make_error_response(
            "Unable to find path %s" % invocation_context.path, 404)

    # validate request
    validator = RequestValidator(invocation_context,
                                 aws_stack.connect_to_service("apigateway"))
    if not validator.is_request_valid():
        return make_error_response("Invalid request body", 400)

    # Resolve the method configuration once, including the ANY-method fallbacks
    # (HttpMethod '*' with ResourcePath '/*' produces 'X-AMAZON-APIGATEWAY-ANY-METHOD').
    # The API key check, the integration and the response templates must all use
    # the same method config. Otherwise an ANY method's apiKeyRequired would be ignored.
    resource_methods = resource.get("resourceMethods") or {}
    method_config = (resource_methods.get(method)
                     or resource_methods.get("ANY")
                     or resource_methods.get("X-AMAZON-APIGATEWAY-ANY-METHOD")
                     or {})

    api_key_required = method_config.get("apiKeyRequired")
    if not is_api_key_valid(api_key_required, headers,
                            invocation_context.stage):
        return make_error_response("Access denied - invalid API key", 403)

    integration = method_config.get("methodIntegration")
    if not integration:
        if method == "OPTIONS" and "Origin" in headers:
            # default to returning CORS headers if this is an OPTIONS request
            return get_cors_response(headers)
        return make_error_response(
            "Unable to find integration for: %s %s (%s)" %
            (method, invocation_path, raw_path),
            404,
        )

    int_responses = integration.get("integrationResponses") or {}
    response_templates = (int_responses.get("200")
                          or {}).get("responseTemplates") or {}

    # update fields in invocation context, then forward request to next handler
    invocation_context.resource = resource
    invocation_context.resource_path = extracted_path
    invocation_context.response_templates = response_templates
    invocation_context.integration = integration

    return invoke_rest_api_integration(invocation_context)
Esempio n. 3
0
    def render(self, api_context: ApiInvocationContext):
        """Apply the integration's ``application/json`` request template.

        :param api_context: invocation context. ``integration["requestTemplates"]``
            is read from it, and the request body comes from ``data_as_string()``.
        :return: the rendered template. If no JSON request template is configured,
            the unmodified request body is returned instead.
        """
        LOG.info("Method request body before transformations: %s",
                 to_str(api_context.data_as_string()))
        # tolerate a "requestTemplates" key that is present but set to None
        request_templates = api_context.integration.get("requestTemplates") or {}
        template = request_templates.get(APPLICATION_JSON, {})
        if not template:
            return api_context.data_as_string()

        variables = self.build_variables_mapping(api_context)
        result = self.render_vtl(template, variables=variables)
        # lazy %-formatting, so the message is only built when INFO is enabled
        LOG.info("Endpoint request body after transformations:\n%s", result)
        return result
Esempio n. 4
0
 def build_variables_mapping(api_context: ApiInvocationContext):
     """Assemble the variables dict handed to ``render_vtl``.

     Keys:

     - ``context``: the request context, or ``{}`` if unset.
     - ``stage_variables``: the stage variables, or ``{}`` if unset.
     - ``input``: the body as a string, plus the path, query-string and
       header parameters.
     """
     # TODO: make this (dict) an object so usages of "render_vtl" variables are defined
     request_context = api_context.context or {}
     stage_vars = api_context.stage_variables or {}
     body_text = api_context.data_as_string()
     request_params = {
         "path": api_context.path_params,
         "querystring": api_context.query_params(),
         "header": api_context.headers,
     }
     return {
         "context": request_context,
         "stage_variables": stage_vars,
         "input": {
             "body": body_text,
             "params": request_params,
         },
     }
Esempio n. 5
0
def to_invocation_context(
        request: Request,
        url_params: Dict[str, Any] = None) -> ApiInvocationContext:
    """
    Converts an HTTP Request object into an ApiInvocationContext.

    :param request: the original request
    :param url_params: the parameters extracted from the URL matching rules. May be
        omitted or ``None``, in which case no stage is set on the context.
    :return: the ApiInvocationContext
    """
    # the default is None, so normalize it before calling .get() below
    url_params = url_params or {}

    method = request.method
    path = request.full_path if request.query_string else request.path
    data = restore_payload(request)
    headers = Headers(request.headers)

    # adjust the X-Forwarded-For header
    x_forwarded_for = headers.getlist("X-Forwarded-For")
    # remote_addr can be None depending on the server environment;
    # skip it rather than crashing in str.join()
    if request.remote_addr is not None:
        x_forwarded_for.append(request.remote_addr)
    x_forwarded_for.append(request.host)
    headers["X-Forwarded-For"] = ", ".join(x_forwarded_for)

    # set the x-localstack-edge header, it is used to parse the domain
    headers[HEADER_LOCALSTACK_EDGE_URL] = request.host_url.strip("/")

    # FIXME: Use the already parsed url params instead of parsing them into the ApiInvocationContext part-by-part.
    #   We already would have all params at hand to avoid _all_ the parsing, but the parsing
    #   has side-effects (f.e. setting the region in a thread local)!
    #   It would be best to use a small (immutable) context for the already parsed params and the Request object
    #   and use it everywhere.
    return ApiInvocationContext(method,
                                path,
                                data,
                                headers,
                                stage=url_params.get("stage"))
Esempio n. 6
0
    def render(self, api_context: ApiInvocationContext):
        """Apply the matching ``application/json`` response template to the response.

        Template selection:

        - If the response's status code has an integration response, that entry
          is used.
        - Otherwise, if exactly one integration response is configured, it is used
          as the default.
        - Otherwise no template is applied.

        Side effects: sets ``api_context.data`` to the raw response content. When a
        template is applied, the rendered result also replaces the response's
        ``_content``.

        :return: the (possibly transformed) response content
        """
        response = api_context.response
        integration = api_context.integration
        # we set context data with the response content because later on we use context data as
        # the body field in the template. We need to improve this by using the right source
        # depending on the type of templates.
        api_context.data = response._content
        int_responses = integration.get("integrationResponses") or {}
        if not int_responses:
            return response._content

        entries = list(int_responses.keys())
        return_code = str(response.status_code)
        if return_code not in entries:
            if len(entries) > 1:
                LOG.info("Found multiple integration response status codes: %s",
                         entries)
                return response._content
            # a single configured integration response acts as the default
            return_code = entries[0]

        response_templates = int_responses[return_code].get(
            "responseTemplates") or {}
        template = response_templates.get(APPLICATION_JSON, {})
        if not template:
            # return the content (not the Response object), like every other exit path
            return response._content

        variables = self.build_variables_mapping(api_context)
        response._content = self.render_vtl(template, variables=variables)
        LOG.info("Endpoint response body after transformations:\n%s",
                 response._content)
        return response._content
Esempio n. 7
0
    def render(self, api_context: ApiInvocationContext,
               **kwargs) -> Union[bytes, str]:
        """Apply the matching ``application/json`` response template to the response.

        :param api_context: invocation context. It provides ``integration`` and,
            unless overridden, the ``response``.
        :param kwargs: an optional ``response`` overrides ``api_context.response``
        :return: the (possibly transformed) response content

        Template selection:

        - If the response's status code has an integration response, that entry
          is used.
        - Otherwise, if exactly one integration response is configured, it is used
          as the default.
        - Otherwise no template is applied.

        Side effects: sets ``api_context.data`` to the raw response content. When a
        template is applied, the rendered result also replaces the response's
        ``_content``.
        """
        # XXX: keep backwards compatibility until we migrate all integrations to this new classes
        # api_context contains a response object that we want slowly remove from it
        response = kwargs.get("response")
        # Do not use ``kwargs["response"] or ...``: a requests.Response is falsy for
        # status codes >= 400, so an explicitly passed error response would be dropped.
        if response is None or response == "":
            response = api_context.response
        integration = api_context.integration
        # we set context data with the response content because later on we use context data as
        # the body field in the template. We need to improve this by using the right source
        # depending on the type of templates.
        api_context.data = response._content

        integration_responses = integration.get("integrationResponses") or {}
        if not integration_responses:
            return response._content
        entries = list(integration_responses.keys())
        return_code = str(response.status_code)
        if return_code not in entries:
            if len(entries) > 1:
                LOG.info("Found multiple integration response status codes: %s",
                         entries)
                return response._content
            # a single configured integration response acts as the default
            return_code = entries[0]

        response_templates = integration_responses[return_code].get(
            "responseTemplates") or {}
        template = response_templates.get(APPLICATION_JSON, {})
        if not template:
            return response._content

        variables = self.build_variables_mapping(api_context)
        response._content = self.render_vtl(template, variables=variables)
        LOG.info("Endpoint response body after transformations:\n%s",
                 response._content)
        return response._content
Esempio n. 8
0
def invoke_rest_api_integration(invocation_context: ApiInvocationContext):
    """Call the backend integration, then apply the configured response parameters.

    Any exception raised along the way is logged with its traceback. It is then
    converted into a 400 error response that names the API ID.
    """
    try:
        backend_response = invoke_rest_api_integration_backend(invocation_context)
        # TODO remove this setter once all the integrations are migrated to the new response
        #  handling
        invocation_context.response = backend_response
        return apply_response_parameters(invocation_context)
    except Exception as error:
        api_id = invocation_context.api_id
        message = f"Error invoking integration for API Gateway ID '{api_id}': {error}"
        LOG.exception(message)
        return make_error_response(message, 400)
Esempio n. 9
0
    def forward_request(self, method, path, data, headers):
        """Route a request to REST API invocation, test-invoke handling, or the parent.

        A request counts as a user request when its path matches
        ``PATH_REGEX_USER_REQUEST`` or its edge-URL header contains ``execute-api``.
        Such requests are invoked directly, and a non-``None`` result is returned.
        Otherwise:

        - TestInvokeMethod calls go to ``_handle_test_invoke_method``;
        - anything else is delegated to the parent ``forward_request``.
        """
        context = ApiInvocationContext(method, path, data, headers)

        edge_url = headers.get(HEADER_LOCALSTACK_EDGE_URL, "")
        is_user_request = (bool(re.match(PATH_REGEX_USER_REQUEST, path))
                           or "execute-api" in edge_url)
        if is_user_request:
            invocation_result = invoke_rest_api_from_request(context)
            if invocation_result is not None:
                return invocation_result

        if not helpers.is_test_invoke_method(method, path):
            return super().forward_request(method, path, data, headers)
        return self._handle_test_invoke_method(context)
Esempio n. 10
0
def get_target_resource_details(
        invocation_context: ApiInvocationContext) -> Tuple[str, Dict]:
    """Find the API Gateway resource (path pattern + resource dict) for the request.

    On success, the resource is also stored on ``invocation_context.resource``.

    :return: ``(extracted_path, resource)``. NOTE: despite the annotation,
        ``(None, None)`` is returned when ``get_resource_for_path`` raises or its
        result cannot be unpacked.
    """
    api_paths = helpers.get_rest_api_paths(
        rest_api_id=invocation_context.api_id,
        region_name=invocation_context.region_name,
    )
    request_path = invocation_context.invocation_path
    try:
        matched_path, matched_resource = get_resource_for_path(
            path=request_path, path_map=api_paths)
        invocation_context.resource = matched_resource
    except Exception:
        # best-effort lookup: callers treat a missing resource as "not found"
        return None, None
    return matched_path, matched_resource
Esempio n. 11
0
 def invoke(self, invocation_context: ApiInvocationContext) -> Response:
     """Forward the API Gateway request to the SNS service endpoint via HTTP POST.

     The request body is produced by ``self.request_templates``. The region is the
     4th ``:``-separated field of the integration URI.

     :raises Exception: re-raises any error raised while rendering the template
     """
     invocation_context.context = get_event_request_context(
         invocation_context)
     try:
         payload = self.request_templates.render(invocation_context)
     except Exception as e:
         # Pass ``e`` through a %s placeholder. An extra argument without one makes
         # the logging module report a formatting error instead of this message.
         LOG.warning("Failed to apply template for SNS integration: %s", e)
         raise
     integration = invocation_context.integration
     uri = integration.get("uri") or integration.get("integrationUri") or ""
     region_name = uri.split(":")[3]
     headers = aws_stack.mock_aws_request_headers(service="sns",
                                                  region_name=region_name)
     return make_http_request(config.service_url("sns"),
                              method="POST",
                              headers=headers,
                              data=payload)
Esempio n. 12
0
    def forward_request(self, method, path, data, headers):
        """Route API Gateway requests to the matching handler.

        Handling, in order:

        1. User requests (path matches ``PATH_REGEX_USER_REQUEST`` or the edge-URL
           header contains ``execute-api``) are invoked first; a non-``None``
           result is returned.
        2. The body is then JSON-decoded, and the request is dispatched to the
           authorizer, documentation-part, validator, gateway-response or
           base-path-mapping handlers based on the path.
        3. TestInvokeMethod calls are executed and returned as a
           ``{"status", "body", "headers"}`` dict.

        Otherwise returns ``True``. Presumably this tells the proxy to continue
        normal processing — confirm against the listener contract.
        """
        invocation_context = ApiInvocationContext(method, path, data, headers)

        forwarded_for = headers.get(HEADER_LOCALSTACK_EDGE_URL, "")
        if re.match(PATH_REGEX_USER_REQUEST,
                    path) or "execute-api" in forwarded_for:
            result = invoke_rest_api_from_request(invocation_context)
            if result is not None:
                return result

        data = data and json.loads(to_str(data))

        if re.match(PATH_REGEX_AUTHORIZERS, path):
            return handle_authorizers(method, path, data, headers)

        if re.match(PATH_REGEX_DOC_PARTS, path):
            return handle_documentation_parts(method, path, data, headers)

        if re.match(PATH_REGEX_VALIDATORS, path):
            return handle_validators(method, path, data, headers)

        if re.match(PATH_REGEX_RESPONSES, path):
            return handle_gateway_responses(method, path, data, headers)

        if re.match(PATH_REGEX_PATH_MAPPINGS, path):
            return handle_base_path_mappings(method, path, data, headers)

        if helpers.is_test_invoke_method(method, path):
            # if call is from test_invoke_api then use http_method to find the integration,
            #   as test_invoke_api makes a POST call to request the test invocation
            match = re.match(PATH_REGEX_TEST_INVOKE_API, path)
            invocation_context.method = match[3]
            if data:
                path_with_query_string = data.get("pathWithQueryString")
                if path_with_query_string:
                    invocation_context.path_with_query_string = path_with_query_string
                invocation_context.data = data.get("body")
                # "headers" may be present but null in the payload; never set None here
                invocation_context.headers = data.get("headers") or {}
            result = invoke_rest_api_from_request(invocation_context)
            return {
                "status": result.status_code,
                "body": to_str(result.content),
                "headers": dict(result.headers),
            }

        return True
Esempio n. 13
0
def invoke_rest_api_integration_backend(
        invocation_context: ApiInvocationContext):
    """Invoke the backend configured by the resolved method integration.

    Dispatches on the integration ``uri`` and ``type``:

    - Lambda (AWS / AWS_PROXY);
    - AWS service actions: Kinesis, Step Functions, S3 GET, SQS / SNS POST;
    - DynamoDB PutItem (AWS_PROXY);
    - HTTP / HTTP_PROXY;
    - MOCK.

    Request and response templates are applied where the branch supports them.

    Side effects:

    - replaces the internal ``HEADER_LOCALSTACK_AUTHORIZATION`` header on
      ``invocation_context.headers`` with an ``Authorization`` header;
    - may set ``path_params``, ``context``, ``stage_variables`` and ``response``
      on the context, depending on the branch.

    :return: the integration response
    :raises Exception: for integration type / URI / method combinations that are
        not implemented
    """
    # define local aliases from invocation context
    invocation_path = invocation_context.path_with_query_string
    method = invocation_context.method
    data = invocation_context.data
    headers = invocation_context.headers
    api_id = invocation_context.api_id
    stage = invocation_context.stage
    resource_path = invocation_context.resource_path
    integration = invocation_context.integration
    # tolerate keys that are present but set to None
    integration_response = integration.get("integrationResponses") or {}
    response_templates = (integration_response.get("200")
                          or {}).get("responseTemplates") or {}
    # extract integration type and path parameters
    relative_path, query_string_params = extract_query_string_params(
        path=invocation_path)
    integration_type_orig = integration.get("type") or integration.get(
        "integrationType") or ""
    integration_type = integration_type_orig.upper()
    uri = integration.get("uri") or integration.get("integrationUri") or ""
    # XXX we need replace the internal Authorization header with an Authorization header set from
    # the customer, even if it's empty that's what's expected in the integration.
    custom_auth_header = invocation_context.headers.pop(
        HEADER_LOCALSTACK_AUTHORIZATION, "")
    invocation_context.headers["Authorization"] = custom_auth_header

    try:
        path_params = extract_path_params(path=relative_path,
                                          extracted_path=resource_path)
        invocation_context.path_params = path_params
    except Exception:
        # best effort: continue without path parameters if they cannot be extracted
        path_params = {}

    if (uri.startswith("arn:aws:apigateway:")
            and ":lambda:path" in uri) or uri.startswith("arn:aws:lambda"):
        if integration_type == "AWS_PROXY":
            return LambdaProxyIntegration().invoke(invocation_context)
        elif integration_type == "AWS":
            func_arn = uri
            if ":lambda:path" in uri:
                func_arn = (uri.split(":lambda:path")[1].split("functions/")
                            [1].split("/invocations")[0])

            headers = helpers.create_invocation_headers(invocation_context)
            invocation_context.context = helpers.get_event_request_context(
                invocation_context)
            invocation_context.stage_variables = helpers.get_stage_variables(
                invocation_context)
            if invocation_context.authorizer_type:
                invocation_context.context[
                    "authorizer"] = invocation_context.auth_context

            request_templates = RequestTemplates()
            payload = request_templates.render(invocation_context)

            # TODO: change this signature to InvocationContext as well!
            result = lambda_api.process_apigateway_invocation(
                func_arn,
                relative_path,
                payload,
                stage,
                api_id,
                headers,
                is_base64_encoded=invocation_context.is_data_base64_encoded,
                path_params=path_params,
                query_string_params=query_string_params,
                method=method,
                resource_path=resource_path,
                request_context=invocation_context.context,
                stage_variables=invocation_context.stage_variables,
            )

            if isinstance(result, FlaskResponse):
                response = flask_to_requests_response(result)
            elif isinstance(result, Response):
                response = result
            else:
                response = LambdaResponse()
                parsed_result = (result if isinstance(result, dict) else
                                 json.loads(str(result or "{}")))
                parsed_result = common.json_safe(parsed_result)
                parsed_result = {} if parsed_result is None else parsed_result
                response.status_code = 200
                response._content = parsed_result
                update_content_length(response)

            # apply custom response template
            invocation_context.response = response
            # (renderer instances are not bound to ``response_templates`` so the
            # template dict of the same name is never shadowed)
            ResponseTemplates().render(invocation_context)
            invocation_context.response.headers["Content-Length"] = str(
                len(response.content or ""))
            return invocation_context.response

        raise Exception(
            f'API Gateway integration type "{integration_type}", action "{uri}", method "{method}"'
        )

    elif integration_type == "AWS":
        if "kinesis:action/" in uri:
            if uri.endswith("kinesis:action/PutRecord"):
                target = kinesis_listener.ACTION_PUT_RECORD
            elif uri.endswith("kinesis:action/PutRecords"):
                target = kinesis_listener.ACTION_PUT_RECORDS
            elif uri.endswith("kinesis:action/ListStreams"):
                target = kinesis_listener.ACTION_LIST_STREAMS
            else:
                LOG.info(
                    "Unexpected API Gateway integration URI '%s' for integration type %s",
                    uri,
                    integration_type,
                )
                target = ""

            try:
                invocation_context.context = helpers.get_event_request_context(
                    invocation_context)
                invocation_context.stage_variables = helpers.get_stage_variables(
                    invocation_context)
                request_templates = RequestTemplates()
                payload = request_templates.render(invocation_context)

            except Exception as e:
                # the exception must go through a placeholder, otherwise logging fails
                LOG.warning("Unable to convert API Gateway payload to str: %s", e)
                raise

            # forward records to target kinesis stream
            headers = aws_stack.mock_aws_request_headers(
                service="kinesis", region_name=invocation_context.region_name)
            headers["X-Amz-Target"] = target

            result = common.make_http_request(
                url=config.service_url("kinesis"),
                data=payload,
                headers=headers,
                method="POST")

            # apply response template
            invocation_context.response = result
            ResponseTemplates().render(invocation_context)
            return invocation_context.response

        elif "states:action/" in uri:
            action = uri.split("/")[-1]

            if APPLICATION_JSON in (integration.get("requestTemplates") or {}):
                request_templates = RequestTemplates()
                payload = request_templates.render(invocation_context)
                payload = json.loads(payload)
            else:
                # the raw request body may arrive as bytes or as str
                raw_body = (data.decode("utf-8")
                            if isinstance(data, bytes) else data)
                payload = json.loads(raw_body)
            client = aws_stack.connect_to_service("stepfunctions")

            if isinstance(payload.get("input"), dict):
                payload["input"] = json.dumps(payload["input"])

            # Hot fix since step functions local package responses: Unsupported Operation: 'StartSyncExecution'
            method_name = (camel_to_snake_case(action)
                           if action != "StartSyncExecution" else
                           "start_execution")

            try:
                # separate name so the HTTP ``method`` alias is not overwritten
                sfn_action = getattr(client, method_name)
            except AttributeError:
                msg = "Invalid step function action: %s" % method_name
                LOG.error(msg)
                return make_error_response(msg, 400)

            result = sfn_action(**payload)
            # Drop only the "ResponseMetadata" key. Note that
            # `k not in "ResponseMetadata"` would be a substring test.
            result = json_safe({
                key: value
                for key, value in result.items() if key != "ResponseMetadata"
            })
            response = requests_response(
                content=result,
                headers=aws_stack.mock_aws_request_headers(),
            )

            if action == "StartSyncExecution":
                # poll for the execution result and return it
                result = await_sfn_execution_result(result["executionArn"])
                result_status = result.get("status")
                if result_status != "SUCCEEDED":
                    return make_error_response(
                        "StepFunctions execution %s failed with status '%s'" %
                        (result["executionArn"], result_status),
                        500,
                    )
                result = json_safe(result)
                response = requests_response(content=result)

            # apply response templates
            invocation_context.response = response
            ResponseTemplates().render(invocation_context)
            return response
        # https://docs.aws.amazon.com/apigateway/api-reference/resource/integration/
        elif ("s3:path/" in uri or "s3:action/" in uri) and method == "GET":
            s3 = aws_stack.connect_to_service("s3")
            uri = apply_request_parameters(
                uri,
                integration=integration,
                path_params=path_params,
                query_params=query_string_params,
            )
            uri_match = re.match(TARGET_REGEX_PATH_S3_URI, uri) or re.match(
                TARGET_REGEX_ACTION_S3_URI, uri)
            if uri_match:
                bucket, object_key = uri_match.group("bucket", "object")
                LOG.debug("Getting request for bucket %s object %s", bucket,
                          object_key)
                try:
                    s3_object = s3.get_object(Bucket=bucket, Key=object_key)
                except s3.exceptions.NoSuchKey:
                    msg = "Object %s not found" % object_key
                    LOG.debug(msg)
                    return make_error_response(msg, 404)

                headers = aws_stack.mock_aws_request_headers(service="s3")

                if s3_object.get("ContentType"):
                    headers["Content-Type"] = s3_object["ContentType"]

                # stream used so large files do not fill memory
                response = request_response_stream(stream=s3_object["Body"],
                                                   headers=headers)
                return response
            else:
                msg = "Request URI does not match s3 specifications"
                LOG.warning(msg)
                return make_error_response(msg, 400)

        if method == "POST":
            if uri.startswith("arn:aws:apigateway:") and ":sqs:path" in uri:
                template = integration["requestTemplates"][APPLICATION_JSON]
                account_id, queue = uri.split("/")[-2:]
                region_name = uri.split(":")[3]
                request_templates = RequestTemplates()
                payload = request_templates.render(invocation_context)
                if "GetQueueUrl" in template or "CreateQueue" in template:
                    new_request = f"{payload}&QueueName={queue}"
                else:
                    queue_url = f"{config.get_edge_url()}/{account_id}/{queue}"
                    new_request = f"{payload}&QueueUrl={queue_url}"
                headers = aws_stack.mock_aws_request_headers(
                    service="sqs", region_name=region_name)

                url = urljoin(config.service_url("sqs"),
                              f"{TEST_AWS_ACCOUNT_ID}/{queue}")
                result = common.make_http_request(url,
                                                  method="POST",
                                                  headers=headers,
                                                  data=new_request)
                return result
            elif uri.startswith("arn:aws:apigateway:") and ":sns:path" in uri:
                invocation_context.context = helpers.get_event_request_context(
                    invocation_context)
                invocation_context.stage_variables = helpers.get_stage_variables(
                    invocation_context)

                integration_response = SnsIntegration().invoke(
                    invocation_context)
                return apply_request_response_templates(
                    integration_response,
                    response_templates,
                    content_type=APPLICATION_JSON)

        raise Exception(
            'API Gateway AWS integration action URI "%s", method "%s" not yet implemented'
            % (uri, method))

    elif integration_type == "AWS_PROXY":
        if uri.startswith("arn:aws:apigateway:") and ":dynamodb:action" in uri:
            # arn:aws:apigateway:us-east-1:dynamodb:action/PutItem&Table=MusicCollection
            table_name = uri.split(":dynamodb:action")[1].split("&Table=")[1]
            action = uri.split(":dynamodb:action")[1].split("&Table=")[0]

            if "PutItem" in action and method == "PUT":
                response_template = response_templates.get("application/json")

                if response_template is None:
                    msg = "Invalid response template defined in integration response."
                    LOG.info("%s Existing: %s", msg, response_templates)
                    return make_error_response(msg, 404)

                response_template = json.loads(response_template)
                if response_template["TableName"] != table_name:
                    msg = "Invalid table name specified in integration response template."
                    return make_error_response(msg, 404)

                dynamo_client = aws_stack.connect_to_resource("dynamodb")
                table = dynamo_client.Table(table_name)

                event_data = {}
                data_dict = json.loads(data)
                for key, _ in response_template["Item"].items():
                    event_data[key] = data_dict[key]

                table.put_item(Item=event_data)
                response = requests_response(event_data)
                return response
        else:
            raise Exception(
                'API Gateway action uri "%s", integration type %s not yet implemented'
                % (uri, integration_type))

    elif integration_type in ["HTTP_PROXY", "HTTP"]:

        if ":servicediscovery:" in uri:
            # check if this is a servicediscovery integration URI
            client = aws_stack.connect_to_service("servicediscovery")
            service_id = uri.split("/")[-1]
            instances = client.list_instances(
                ServiceId=service_id)["Instances"]
            instance = (instances or [None])[0]
            if instance and instance.get("Id"):
                uri = "http://%s/%s" % (instance["Id"],
                                        invocation_path.lstrip("/"))

        # apply custom request template
        invocation_context.context = helpers.get_event_request_context(
            invocation_context)
        invocation_context.stage_variables = helpers.get_stage_variables(
            invocation_context)
        request_templates = RequestTemplates()
        payload = request_templates.render(invocation_context)

        if isinstance(payload, dict):
            payload = json.dumps(payload)

        uri = apply_request_parameters(
            uri,
            integration=integration,
            path_params=path_params,
            query_params=query_string_params,
        )
        result = requests.request(method=method,
                                  url=uri,
                                  data=payload,
                                  headers=headers)
        # apply custom response template
        invocation_context.response = result
        ResponseTemplates().render(invocation_context)
        return invocation_context.response

    elif integration_type == "MOCK":
        mock_integration = MockIntegration()
        return mock_integration.invoke(invocation_context)

    if method == "OPTIONS":
        # fall back to returning CORS headers if this is an OPTIONS request
        return get_cors_response(headers)

    raise Exception(
        'API Gateway integration type "%s", method "%s", URI "%s" not yet implemented'
        % (integration_type, method, uri))
Esempio n. 14
0
    def invoke(self, invocation_context: ApiInvocationContext):
        """Invoke the Lambda function behind an API Gateway integration.

        Resolves the function ARN from the integration URI, renders the request
        template, invokes the function via
        ``lambda_api.process_apigateway_invocation`` and converts the result
        into a ``requests``-style response. The response templates are then
        applied to it.

        :param invocation_context: context of the current API Gateway
            invocation. Its ``context``, ``path_params`` (when extractable) and
            ``response`` attributes are updated in place.
        :return: ``invocation_context.response`` after the response templates
            have been rendered.
        """
        uri = (invocation_context.integration.get("uri")
               or invocation_context.integration.get("integrationUri") or "")
        relative_path, query_string_params = extract_query_string_params(
            path=invocation_context.path_with_query_string)
        api_id = invocation_context.api_id
        stage = invocation_context.stage
        headers = invocation_context.headers
        resource_path = invocation_context.resource_path
        invocation_context.context = get_event_request_context(
            invocation_context)
        try:
            path_params = extract_path_params(path=relative_path,
                                              extracted_path=resource_path)
            invocation_context.path_params = path_params
        except Exception:
            # best effort: if path params cannot be extracted, continue without them
            path_params = {}

        func_arn = uri
        if ":lambda:path" in uri:
            # URIs like "...:lambda:path/<version>/functions/<arn>/invocations"
            # carry the function ARN between "functions/" and "/invocations"
            func_arn = uri.split(":lambda:path")[1].split(
                "functions/")[1].split("/invocations")[0]

        if invocation_context.authorizer_type:
            authorizer_context = {
                invocation_context.authorizer_type:
                invocation_context.auth_context
            }
            invocation_context.context["authorizer"] = authorizer_context

        payload = self.request_templates.render(invocation_context)

        # TODO: change this signature to InvocationContext as well!
        result = lambda_api.process_apigateway_invocation(
            func_arn,
            relative_path,
            payload,
            stage,
            api_id,
            headers,
            is_base64_encoded=invocation_context.is_data_base64_encoded,
            path_params=path_params,
            query_string_params=query_string_params,
            method=invocation_context.method,
            resource_path=resource_path,
            request_context=invocation_context.context,
            stage_variables=invocation_context.stage_variables,
        )

        if isinstance(result, FlaskResponse):
            response = flask_to_requests_response(result)
        elif isinstance(result, Response):
            response = result
        else:
            response = LambdaResponse()
            if isinstance(result, dict):
                parsed_result = result
            else:
                # an empty/None result is treated as an empty JSON object
                raw_result = result or "{}"
                if isinstance(raw_result, bytes):
                    # str(b"...") would yield "b'...'", which is not valid JSON
                    raw_result = raw_result.decode("utf-8")
                parsed_result = json.loads(str(raw_result))
            parsed_result = common.json_safe(parsed_result)
            parsed_result = {} if parsed_result is None else parsed_result
            response.status_code = int(parsed_result.get("statusCode", 200))
            parsed_headers = parsed_result.get("headers", {})
            if parsed_headers is not None:
                response.headers.update(parsed_headers)
            try:
                result_body = parsed_result.get("body")
                if isinstance(result_body, (dict, list)):
                    # the response content must be bytes, not str
                    response._content = json.dumps(result_body).encode("utf-8")
                else:
                    body_bytes = to_bytes(to_str(result_body or ""))
                    if parsed_result.get("isBase64Encoded", False):
                        body_bytes = base64.b64decode(body_bytes)
                    response._content = body_bytes
            except Exception as e:
                LOG.warning("Couldn't set Lambda response content: %s", e)
                response._content = b"{}"
            response.multi_value_headers = parsed_result.get(
                "multiValueHeaders") or {}

        # apply custom response template
        self.update_content_length(response)
        invocation_context.response = response

        self.response_templates.render(invocation_context)
        return invocation_context.response