def test_api_gateway():
    # GIVEN a resolver configured for REST API Gateway (APIGatewayProxyEvent) events
    app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)

    @app.get("/my/path")
    def html_route() -> Response:
        # the resolver must expose the incoming event as an APIGatewayProxyEvent
        assert isinstance(app.current_event, APIGatewayProxyEvent)
        return Response(200, content_types.TEXT_HTML, "foo")

    # WHEN the resolver handles the canned API Gateway event
    response = app(LOAD_GW_EVENT, {})

    # THEN the route's Response is rendered as-is
    # AND current_event had the APIGatewayProxyEvent type (checked inside the route)
    rendered = (response["statusCode"], response["headers"]["Content-Type"], response["body"])
    assert rendered == (200, content_types.TEXT_HTML, "foo")
def test_api_gateway_v1():
    # GIVEN an HTTP API v1 proxy type event
    app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)
    expected_body = json.dumps({"foo": "value"})

    @app.get("/my/path")
    def get_lambda() -> Response:
        assert isinstance(app.current_event, APIGatewayProxyEvent)
        # the context passed to app(...) must be exposed unchanged on the resolver
        assert app.lambda_context == {}
        return Response(200, content_types.APPLICATION_JSON, expected_body)

    # WHEN calling the event handler
    result = app(LOAD_GW_EVENT, {})

    # THEN process event correctly
    # AND set the current_event type as APIGatewayProxyEvent
    # AND return the pre-serialized JSON string body unchanged
    assert result["statusCode"] == 200
    assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
    assert result["body"] == expected_body
def test_debug_unhandled_exceptions_debug_off():
    # GIVEN a resolver with debug mode explicitly disabled
    app = ApiGatewayResolver(debug=False)
    assert not app._debug

    # AND a route that raises an unhandled exception
    @app.get("/raises-error")
    def boom():
        raise RuntimeError("Foo")

    failing_request = {"path": "/raises-error", "httpMethod": "GET"}

    # WHEN the resolver handles a request for that route
    # THEN the original exception propagates to the caller
    with pytest.raises(RuntimeError) as exc_info:
        app(failing_request, None)

    # AND the exception keeps its original arguments
    assert exc_info.value.args == ("Foo",)
def test_rest_api():
    # GIVEN a route that returns a plain dict (including a Decimal value)
    app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)
    payload = {"foo": "value", "second": Decimal("100.01")}

    @app.get("/my/path")
    def rest_func() -> Dict:
        return payload

    # WHEN the resolver handles the canned API Gateway event
    result = app(LOAD_GW_EVENT, {})

    # THEN the dict is turned into a JSON REST response automatically
    assert result["statusCode"] == 200
    assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
    # expected body: compact separators plus the Encoder class (which handles the Decimal)
    compact_json = json.dumps(payload, separators=(",", ":"), indent=None, cls=Encoder)
    assert result["body"] == compact_json
def test_custom_cors_config():
    # GIVEN a resolver configured with a fully custom CORS policy
    allow_header = ["foo2"]
    cors_config = CORSConfig(
        allow_origin="https://foo1",
        expose_headers=["foo1"],
        allow_headers=allow_header,
        max_age=100,
        allow_credentials=True,
    )
    app = ApiGatewayResolver(cors=cors_config)

    # AND one route that inherits the app-wide CORS policy and one that opts out
    @app.get("/cors")
    def get_with_cors():
        return {}

    @app.get("/another-one", cors=False)
    def another_one():
        return {}

    # WHEN calling the CORS-enabled route
    cors_result = app({"path": "/cors", "httpMethod": "GET"}, None)

    # THEN it carries every header derived from the custom CORS configuration
    assert "headers" in cors_result
    received = cors_result["headers"]
    expected_headers = {
        "Content-Type": content_types.APPLICATION_JSON,
        "Access-Control-Allow-Origin": cors_config.allow_origin,
        # user-supplied headers are merged with the always-required ones, de-duplicated and sorted
        "Access-Control-Allow-Headers": ",".join(sorted(set(allow_header + cors_config._REQUIRED_HEADERS))),
        "Access-Control-Expose-Headers": ",".join(cors_config.expose_headers),
        "Access-Control-Max-Age": str(cors_config.max_age),
        "Access-Control-Allow-Credentials": "true",
    }
    for header_name, header_value in expected_headers.items():
        assert received[header_name] == header_value, header_name

    # AND the very same config object is stored on the app
    assert isinstance(app._cors, CORSConfig)
    assert app._cors is cors_config

    # AND the route that opted out of CORS gets no "Access-Control" headers
    plain_headers = app({"path": "/another-one", "httpMethod": "GET"}, None)["headers"]
    assert "Access-Control-Allow-Origin" not in plain_headers
def test_api_gateway_v2():
    # GIVEN a resolver configured for HTTP API v2 proxy events
    app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEventV2)

    @app.post("/my/path")
    def my_path() -> Response:
        incoming = app.current_event
        assert isinstance(incoming, APIGatewayProxyEventV2)
        # echo the "username" field of the JSON request body as plain text
        return Response(200, content_types.TEXT_PLAIN, incoming.json_body["username"])

    # WHEN handling the sample v2 event fixture
    v2_event = load_event("apiGatewayProxyV2Event.json")
    result = app(v2_event, {})

    # THEN the event is processed correctly
    # AND current_event had the APIGatewayProxyEventV2 type (checked inside the route)
    assert result["statusCode"] == 200
    assert result["headers"]["Content-Type"] == content_types.TEXT_PLAIN
    assert result["body"] == "tom"
def test_alb_event():
    # GIVEN a resolver configured for Application Load Balancer events
    app = ApiGatewayResolver(proxy_type=ProxyEventType.ALBEvent)

    @app.get("/lambda")
    def alb_route():
        # the route sees an ALBEvent and the exact context passed to app(...)
        assert isinstance(app.current_event, ALBEvent)
        assert app.lambda_context == {}
        return Response(200, content_types.TEXT_HTML, "foo")

    # WHEN handling the sample ALB event fixture
    alb_event = load_event("albEvent.json")
    result = app(alb_event, {})

    # THEN the event is processed correctly
    # AND current_event had the ALBEvent type (checked inside the route)
    rendered = (result["statusCode"], result["headers"]["Content-Type"], result["body"])
    assert rendered == (200, content_types.TEXT_HTML, "foo")
def test_base64_encode():
    import base64
    import gzip

    # GIVEN a compress=True route that returns binary (bytes) content
    # AND a request whose Accept-Encoding includes gzip
    app = ApiGatewayResolver()
    mock_event = {"path": "/my/path", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}}
    image_bytes = read_media("idempotent_sequence_exception.png")

    @app.get("/my/path", compress=True)
    def read_image() -> Response:
        return Response(200, "image/png", image_bytes)

    # WHEN calling the event handler
    result = app(mock_event, None)

    # THEN return the body as a base64 encoded string
    assert result["isBase64Encoded"] is True
    body = result["body"]
    assert isinstance(body, str)
    headers = result["headers"]
    assert headers["Content-Encoding"] == "gzip"
    # AND the body really is valid base64 wrapping a gzip stream of the original bytes
    # (validate=True rejects non-alphabet characters instead of silently dropping them)
    assert gzip.decompress(base64.b64decode(body, validate=True)) == image_bytes
def test_cache_control_non_200():
    # GIVEN a route with cache_control set
    # AND the route responds with a non-200 (503) status code
    app = ApiGatewayResolver()

    @app.delete("/fails", cache_control="max-age=600")
    def with_cache_control_has_503() -> Response:
        return Response(503, content_types.TEXT_HTML, "has 503 response")

    def handler(event, context):
        return app.resolve(event, context)

    # WHEN calling the event handler
    result = handler({"path": "/fails", "httpMethod": "DELETE"}, None)

    # THEN the route's 503 status code is returned unchanged
    assert result["statusCode"] == 503
    # AND the configured Cache-Control is replaced by "no-cache" because the status is not 200
    headers = result["headers"]
    assert headers["Content-Type"] == content_types.TEXT_HTML
    assert headers["Cache-Control"] == "no-cache"
def test_cache_control_200():
    # GIVEN a GET route configured with cache_control="max-age=600"
    app = ApiGatewayResolver()

    @app.get("/success", cache_control="max-age=600")
    def cached_route() -> Response:
        return Response(200, content_types.TEXT_HTML, "has 200 response")

    # a thin Lambda-style entry point that delegates to the resolver
    def handler(event, context):
        return app.resolve(event, context)

    # WHEN the route answers with a 200 status code
    response_headers = handler({"path": "/success", "httpMethod": "GET"}, None)["headers"]

    # THEN the configured Cache-Control value is passed through untouched
    assert response_headers["Content-Type"] == content_types.TEXT_HTML
    assert response_headers["Cache-Control"] == "max-age=600"
def test_debug_unhandled_exceptions_debug_on():
    # GIVEN a resolver with debug mode enabled
    app = ApiGatewayResolver(debug=True)
    assert app._debug

    # AND a route that raises an unhandled exception
    @app.get("/raises-error")
    def boom():
        raise RuntimeError("Foo")

    # WHEN the resolver handles a request for that route
    result = app({"path": "/raises-error", "httpMethod": "GET"}, None)

    # THEN the exception is turned into a 500 plain-text response
    # AND the body contains the formatted traceback
    assert result["statusCode"] == 500
    assert "Traceback (most recent call last)" in result["body"]
    assert result["headers"]["Content-Type"] == content_types.TEXT_PLAIN
def test_compress_no_accept_encoding():
    # GIVEN a route with compress=True
    # AND a request whose headers have no "Accept-Encoding" that includes gzip
    app = ApiGatewayResolver()
    expected_value = "Foo"

    @app.get("/my/path", compress=True)
    def return_text() -> Response:
        return Response(200, "text/plain", expected_value)

    # WHEN calling the event handler
    result = app({"path": "/my/path", "httpMethod": "GET", "headers": {}}, None)

    # THEN don't perform any gzip compression
    assert result["isBase64Encoded"] is False
    assert result["body"] == expected_value
    # AND don't advertise a Content-Encoding for the uncompressed body
    assert "Content-Encoding" not in result["headers"]
def test_handling_response_type():
    # GIVEN a route returning a Response whose explicit headers also set Content-Type
    app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)

    @app.get("/my/path")
    def rest_func() -> Response:
        return Response(
            body="Not found",
            status_code=404,
            headers={"Content-Type": "header-content-type-wins", "custom": "value"},
            content_type="used-if-not-set-in-header",
        )

    # WHEN the resolver handles the canned API Gateway event
    result = app(LOAD_GW_EVENT, {})
    headers = result["headers"]

    # THEN status code and body come straight from the Response
    assert result["statusCode"] == 404
    assert result["body"] == "Not found"
    # AND explicit headers take precedence over the content_type argument
    assert headers["Content-Type"] == "header-content-type-wins"
    assert headers["custom"] == "value"
def test_service_error_responses():
    # SCENARIO handling different kinds of service errors being raised
    app = ApiGatewayResolver(cors=CORSConfig())

    def json_dump(obj):
        # compact separators, matching the JSON error body the resolver emits
        return json.dumps(obj, separators=(",", ":"))

    def call_and_assert_error(path, status_code, message):
        """GET ``path`` and assert it produced a JSON error response.

        The expected body is ``{"statusCode": status_code, "message": message}``.
        Returns the raw result so callers can make further assertions on it.
        """
        result = app({"path": path, "httpMethod": "GET"}, None)
        assert result["statusCode"] == status_code
        assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
        assert result["body"] == json_dump({"statusCode": status_code, "message": message})
        return result

    # GIVEN a BadRequestError
    @app.get(rule="/bad-request-error", cors=False)
    def bad_request_error():
        raise BadRequestError("Missing required parameter")

    # WHEN calling the handler with path /bad-request-error
    # THEN return the bad request error response with status code 400
    call_and_assert_error("/bad-request-error", 400, "Missing required parameter")

    # GIVEN an UnauthorizedError
    @app.get(rule="/unauthorized-error", cors=False)
    def unauthorized_error():
        raise UnauthorizedError("Unauthorized")

    # WHEN calling the handler with path /unauthorized-error
    # THEN return the unauthorized error response with status code 401
    call_and_assert_error("/unauthorized-error", 401, "Unauthorized")

    # GIVEN a NotFoundError raised without a message
    @app.get(rule="/not-found-error", cors=False)
    def not_found_error():
        raise NotFoundError

    # WHEN calling the handler with path /not-found-error
    # THEN return the not found error response with status code 404 and the default message
    call_and_assert_error("/not-found-error", 404, "Not found")

    # GIVEN an InternalServerError
    @app.get(rule="/internal-server-error", cors=False)
    def internal_server_error():
        raise InternalServerError("Internal server error")

    # WHEN calling the handler with path /internal-server-error
    # THEN return the internal server error response with status code 500
    call_and_assert_error("/internal-server-error", 500, "Internal server error")

    # GIVEN a ServiceError with a custom status code on a CORS-enabled route
    @app.get(rule="/service-error", cors=True)
    def service_error():
        raise ServiceError(502, "Something went wrong!")

    # WHEN calling the handler with path /service-error
    # THEN return the service error response with status code 502
    result = call_and_assert_error("/service-error", 502, "Something went wrong!")
    # AND the CORS headers are still applied to the error response
    assert "Access-Control-Allow-Origin" in result["headers"]
import json
import os

import boto3
from aws_lambda_powertools import Logger, Metrics, Tracer
from aws_lambda_powertools.logging import correlation_paths
from aws_lambda_powertools.utilities.typing import LambdaContext
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

# Powertools utilities, created once at module scope (outside any handler function).
# https://awslabs.github.io/aws-lambda-powertools-python/#features
tracer = Tracer()
logger = Logger()
metrics = Metrics()
# Resolver that dispatches incoming events to the @app.get routes registered below.
app = ApiGatewayResolver()

# Global variables are reused across execution contexts (if available).
# Uncomment to create a boto3 session once at module scope instead of per request:
# session = boto3.Session()

@app.get("/hello")
def hello():
    """Handle GET /hello and greet the ``name`` query-string parameter.

    Falls back to "universe" when the parameter is absent.
    Returns a dict with a single "message" key.
    """
    who = app.current_event.get_query_string_value(name="name", default_value="universe")
    greeting = f"hello {who}"
    return {"message": greeting}


@app.get("/hello/<name>")
def hello_you(name):
    """Handle GET /hello/<name> and greet the ``name`` captured from the URL path.

    Query-string parameters (``app.current_event.query_string_parameters``) and a
    parsed JSON body (``app.current_event.json_body``) are also available here if needed.
    """
    greeting = f"hello {name}"
    return {"message": greeting}

@metrics.log_metrics(capture_cold_start_metric=True)