def test_cors_preflight():
    """An OPTIONS request with no matching route yields an automatic preflight response."""
    # GIVEN a resolver created with a CORS configuration
    # AND /foo registered for GET and DELETE, plus a POST that passes cors=False
    resolver = ApiGatewayResolver(cors=CORSConfig())

    @resolver.get("/foo")
    def get_foo():
        ...

    @resolver.route(method="delete", rule="/foo")
    def delete_foo():
        ...

    @resolver.post("/foo", cors=False)
    def post_foo_without_cors():
        ...

    # WHEN an OPTIONS event arrives for /foo, for which no OPTIONS route was registered
    preflight_event = {"path": "/foo", "httpMethod": "OPTIONS"}
    response = resolver(preflight_event, None)

    # THEN the response is a 204 without a body
    assert response["statusCode"] == 204
    assert response["body"] is None

    # AND it carries CORS headers but no Content-Type
    # AND the allowed methods list omits POST
    response_headers = response["headers"]
    assert "Content-Type" not in response_headers
    assert "Access-Control-Allow-Origin" in response_headers
    assert response_headers["Access-Control-Allow-Methods"] == "DELETE,GET,OPTIONS"
def test_similar_dynamic_routes():
    """Routes sharing the /accounts/<account_id> prefix each dispatch to their own handler.

    Every handler asserts on the path parameters it receives. The status code of
    each resolution is checked as well, so a request that matched no route at all
    cannot slip through without any handler assertion having run.
    """
    # GIVEN
    app = ApiGatewayResolver()
    event = deepcopy(LOAD_GW_EVENT)

    # WHEN
    # r'^/accounts/(?P<account_id>\\w+\\b)$' # noqa: E800
    @app.get("/accounts/<account_id>")
    def get_account(account_id: str):
        assert account_id == "single_account"

    # r'^/accounts/(?P<account_id>\\w+\\b)/source_networks$' # noqa: E800
    @app.get("/accounts/<account_id>/source_networks")
    def get_account_networks(account_id: str):
        assert account_id == "nested_account"

    # r'^/accounts/(?P<account_id>\\w+\\b)/source_networks/(?P<network_id>\\w+\\b)$' # noqa: E800
    @app.get("/accounts/<account_id>/source_networks/<network_id>")
    def get_network_account(account_id: str, network_id: str):
        assert account_id == "nested_account"
        assert network_id == "network"

    # THEN each path resolves successfully through its own handler
    event["resource"] = "/accounts/{account_id}"
    event["path"] = "/accounts/single_account"
    result = app.resolve(event, None)
    assert result["statusCode"] == 200

    event["resource"] = "/accounts/{account_id}/source_networks"
    event["path"] = "/accounts/nested_account/source_networks"
    result = app.resolve(event, None)
    assert result["statusCode"] == 200

    event["resource"] = "/accounts/{account_id}/source_networks/{network_id}"
    event["path"] = "/accounts/nested_account/source_networks/network"
    result = app.resolve(event, {})
    assert result["statusCode"] == 200
def test_custom_preflight_response():
    """An explicitly registered OPTIONS route supplies the preflight response."""
    # GIVEN a resolver with CORS enabled
    # AND a user-defined OPTIONS route for /some-call
    # AND a second CORS-enabled route on the same path using the CUSTOM method
    resolver = ApiGatewayResolver(cors=CORSConfig())

    @resolver.route(method="OPTIONS", rule="/some-call", cors=True)
    def handle_preflight():
        preflight_headers = {"Access-Control-Allow-Methods": "CUSTOM"}
        return Response(
            status_code=200,
            content_type=content_types.TEXT_HTML,
            body="Foo",
            headers=preflight_headers,
        )

    @resolver.route(method="CUSTOM", rule="/some-call", cors=True)
    def handle_custom():
        ...

    # WHEN an OPTIONS event targets that path
    options_event = {"path": "/some-call", "httpMethod": "OPTIONS"}
    response = resolver(options_event, None)

    # THEN the handler's own status, body and headers come back
    assert response["statusCode"] == 200
    assert response["body"] == "Foo"
    response_headers = response["headers"]
    assert response_headers["Content-Type"] == content_types.TEXT_HTML
    assert "Access-Control-Allow-Origin" in response_headers
    assert response_headers["Access-Control-Allow-Methods"] == "CUSTOM"
def test_compress():
    """A compress=True route gzips and base64-encodes its body when gzip is accepted."""
    # GIVEN a route registered with compress=True
    # AND a request whose "Accept-Encoding" header lists gzip
    resolver = ApiGatewayResolver()
    payload = '{"test": "value"}'
    request = {
        "path": "/my/request",
        "httpMethod": "GET",
        "headers": {"Accept-Encoding": "deflate, gzip"},
    }

    @resolver.get("/my/request", compress=True)
    def compressed_route() -> Response:
        return Response(200, content_types.APPLICATION_JSON, payload)

    # WHEN resolving the request
    response = resolver.resolve(request, None)

    # THEN the body is flagged as base64 and is a string
    assert response["isBase64Encoded"] is True
    encoded_body = response["body"]
    assert isinstance(encoded_body, str)

    # AND decoding then gunzipping it gives back the original payload
    gzipped = base64.b64decode(encoded_body)
    # wbits=MAX_WBITS | 16 makes zlib expect a gzip header and trailer
    inflated = zlib.decompress(gzipped, wbits=zlib.MAX_WBITS | 16)
    assert inflated.decode("UTF-8") == payload

    # AND the encoding is advertised in the headers
    assert response["headers"]["Content-Encoding"] == "gzip"
def test_base64_encode():
    """A compressed image response is returned as a base64-flagged string body."""
    # GIVEN a compress-enabled route returning the PNG loaded by read_media
    # AND a request that accepts gzip
    resolver = ApiGatewayResolver()
    request_headers = {"Accept-Encoding": "deflate, gzip"}
    request = {"path": "/my/path", "httpMethod": "GET", "headers": request_headers}

    @resolver.get("/my/path", compress=True)
    def serve_png() -> Response:
        png = read_media("idempotent_sequence_exception.png")
        return Response(200, "image/png", png)

    # WHEN invoking the resolver
    response = resolver(request, None)

    # THEN the body is a string marked as base64 encoded
    assert response["isBase64Encoded"] is True
    encoded_body = response["body"]
    assert isinstance(encoded_body, str)

    # AND the response declares gzip encoding
    assert response["headers"]["Content-Encoding"] == "gzip"
def test_cors():
    """A route with cors=True returns CORS headers; a route without the flag does not."""
    # GIVEN a function with cors=True
    # AND http method set to GET
    # AND a second GET route registered without the cors flag
    app = ApiGatewayResolver()

    @app.get("/my/path", cors=True)
    def with_cors() -> Response:
        return Response(200, content_types.TEXT_HTML, "test")

    @app.get("/without-cors")
    def without_cors() -> Response:
        return Response(200, content_types.TEXT_HTML, "test")

    def handler(event, context):
        return app.resolve(event, context)

    # WHEN calling the event handler
    result = handler(LOAD_GW_EVENT, None)

    # THEN the headers should include cors headers
    assert "headers" in result
    headers = result["headers"]
    assert headers["Content-Type"] == content_types.TEXT_HTML
    assert headers["Access-Control-Allow-Origin"] == "*"
    assert "Access-Control-Allow-Credentials" not in headers
    assert headers["Access-Control-Allow-Headers"] == ",".join(sorted(CORSConfig._REQUIRED_HEADERS))

    # THEN for routes without cors flag return no cors headers
    # The event must target the registered /without-cors route: an unregistered
    # path would only exercise the not-found response and never run that route.
    mock_event = {"path": "/without-cors", "httpMethod": "GET"}
    result = handler(mock_event, None)
    assert result["statusCode"] == 200
    assert "Access-Control-Allow-Origin" not in result["headers"]
def test_debug_mode_environment_variable(monkeypatch):
    """Debug mode can be switched on through the environment."""
    # GIVEN the debug environment variable is set to "true" before the resolver is built
    monkeypatch.setenv(constants.EVENT_HANDLER_DEBUG_ENV, "true")
    resolver = ApiGatewayResolver()

    # WHEN reading the resolver's debug flag
    debug_enabled = resolver._debug

    # THEN it reports debug mode as enabled
    assert debug_enabled
def test_debug_print_event(capsys):
    """With debug on, the incoming event is written to stdout as JSON."""
    # GIVEN a resolver with debug=True
    resolver = ApiGatewayResolver(debug=True)

    # WHEN an event is resolved
    incoming_event = {"path": "/foo", "httpMethod": "GET"}
    resolver(incoming_event, None)

    # THEN stdout holds a JSON document equal to that event
    captured = capsys.readouterr()
    printed_event = json.loads(captured.out)
    assert printed_event == incoming_event
def test_no_matches_with_cors():
    """The not-found response still carries CORS headers when CORS is configured."""
    # GIVEN a CORS-enabled resolver with no routes registered
    resolver = ApiGatewayResolver(cors=CORSConfig())

    # WHEN a request arrives for a path nothing handles
    unmatched_event = {"path": "/another-one", "httpMethod": "GET"}
    response = resolver(unmatched_event, None)

    # THEN the status is 404
    assert response["statusCode"] == 404
    # AND the allow-origin header is present
    response_headers = response["headers"]
    assert "Access-Control-Allow-Origin" in response_headers
def test_debug_json_formatting():
    """With debug on, dict responses are serialized as indented JSON."""
    # GIVEN a resolver with debug=True
    # AND a route returning a plain dict
    resolver = ApiGatewayResolver(debug=True)
    payload = {"message": "Foo"}

    @resolver.get("/foo")
    def return_payload():
        return payload

    # WHEN the route is invoked
    response = resolver({"path": "/foo", "httpMethod": "GET"}, None)

    # THEN the body matches json.dumps with a four-space indent
    pretty_payload = json.dumps(payload, indent=4)
    assert response["body"] == pretty_payload
def test_no_matches():
    """An event matching none of the registered routes yields a 404 without CORS headers.

    The route table is also verified: every registered function must appear
    exactly once, paired with the upper-cased HTTP method of the decorator
    that registered it.
    """
    # GIVEN an event that does not match any of the given routes
    app = ApiGatewayResolver()

    @app.get("/not_matching_get")
    def get_func():
        raise RuntimeError()

    @app.post("/no_matching_post")
    def post_func():
        raise RuntimeError()

    @app.put("/no_matching_put")
    def put_func():
        raise RuntimeError()

    @app.delete("/no_matching_delete")
    def delete_func():
        raise RuntimeError()

    @app.patch("/no_matching_patch")
    def patch_func():
        raise RuntimeError()

    # Also check the route configurations
    # Comparing complete mappings (instead of an if/elif chain with no else)
    # means a route whose func is not one of the five above fails the test
    # rather than being silently skipped.
    expected_methods = {
        get_func: "GET",
        post_func: "POST",
        put_func: "PUT",
        delete_func: "DELETE",
        patch_func: "PATCH",
    }
    routes = app._routes
    assert len(routes) == 5
    assert {route.func: route.method for route in routes} == expected_methods

    # WHEN calling the handler
    # THEN return a 404
    result = app.resolve(LOAD_GW_EVENT, None)
    assert result["statusCode"] == 404
    # AND cors headers are not returned
    assert "Access-Control-Allow-Origin" not in result["headers"]
def test_compress_no_accept_encoding():
    """compress=True is ignored when the request does not accept gzip."""
    # GIVEN a compress-enabled route
    # AND a request carrying an empty headers mapping
    resolver = ApiGatewayResolver()
    plain_text = "Foo"
    request = {"path": "/my/path", "httpMethod": "GET", "headers": {}}

    @resolver.get("/my/path", compress=True)
    def plain_route() -> Response:
        return Response(200, content_types.TEXT_PLAIN, plain_text)

    # WHEN invoking the resolver
    response = resolver(request, None)

    # THEN the body is returned untouched and not base64 encoded
    assert response["isBase64Encoded"] is False
    assert response["body"] == plain_text
def test_include_rule_matching():
    """Dynamic path segments are handed to the handler under their rule names."""
    # GIVEN a route made of two dynamic segments
    resolver = ApiGatewayResolver()

    @resolver.get("/<name>/<my_id>")
    def echo_id(my_id: str, name: str) -> Response:
        assert name == "my"
        return Response(200, content_types.TEXT_HTML, my_id)

    # WHEN resolving the shared API Gateway fixture event
    response = resolver(LOAD_GW_EVENT, {})

    # THEN the handler ran and echoed the second segment
    assert response["statusCode"] == 200
    response_headers = response["headers"]
    assert response_headers["Content-Type"] == content_types.TEXT_HTML
    assert response["body"] == "path"
def test_non_word_chars_route(req):
    """The value supplied as ``req`` is matched and passed through as ``account_id``."""
    # GIVEN a copy of the fixture event pointed at /accounts/<req>
    resolver = ApiGatewayResolver()
    request = deepcopy(LOAD_GW_EVENT)
    request["resource"] = "/accounts/{account_id}"
    request["path"] = f"/accounts/{req}"

    # AND a route capturing the account id
    @resolver.get("/accounts/<account_id>")
    def check_account(account_id: str):
        assert account_id == f"{req}"

    # WHEN resolving the event
    response = resolver.resolve(request, None)

    # THEN the route matched and completed successfully
    assert response["statusCode"] == 200
def test_custom_cors_config():
    """A custom CORSConfig drives the CORS headers of routes that do not opt out."""
    # GIVEN a resolver built with a fully customised CORS configuration
    extra_allow_headers = ["foo2"]
    custom_cors = CORSConfig(
        allow_origin="https://foo1",
        expose_headers=["foo1"],
        allow_headers=extra_allow_headers,
        max_age=100,
        allow_credentials=True,
    )
    resolver = ApiGatewayResolver(cors=custom_cors)

    # AND one route using the default CORS behaviour and one passing cors=False
    @resolver.get("/cors")
    def cors_route():
        return {}

    @resolver.get("/another-one", cors=False)
    def cors_free_route():
        return {}

    # WHEN calling the default route
    response = resolver({"path": "/cors", "httpMethod": "GET"}, None)

    # THEN the response carries the customised CORS headers
    assert "headers" in response
    response_headers = response["headers"]
    assert response_headers["Content-Type"] == APPLICATION_JSON
    assert response_headers["Access-Control-Allow-Origin"] == custom_cors.allow_origin
    # Allowed headers are the custom ones merged with the required ones, de-duplicated and sorted
    merged_allow_headers = set(extra_allow_headers + custom_cors._REQUIRED_HEADERS)
    assert response_headers["Access-Control-Allow-Headers"] == ",".join(sorted(merged_allow_headers))
    assert response_headers["Access-Control-Expose-Headers"] == ",".join(custom_cors.expose_headers)
    assert response_headers["Access-Control-Max-Age"] == str(custom_cors.max_age)
    assert "Access-Control-Allow-Credentials" in response_headers
    assert response_headers["Access-Control-Allow-Credentials"] == "true"

    # AND the resolver holds the very same configuration object
    assert isinstance(resolver._cors, CORSConfig)
    assert resolver._cors is custom_cors

    # AND the route that opted out returns no allow-origin header
    response = resolver({"path": "/another-one", "httpMethod": "GET"}, None)
    assert "Access-Control-Allow-Origin" not in response["headers"]
def test_api_gateway():
    """REST API proxy events are exposed to handlers as APIGatewayProxyEvent."""
    # GIVEN a resolver configured for the APIGatewayProxyEvent proxy type
    resolver = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)

    @resolver.get("/my/path")
    def html_route() -> Response:
        # The handler itself verifies the type of the current event
        assert isinstance(resolver.current_event, APIGatewayProxyEvent)
        return Response(200, content_types.TEXT_HTML, "foo")

    # WHEN resolving the fixture event
    response = resolver(LOAD_GW_EVENT, {})

    # THEN the handler's response is returned as-is
    assert response["statusCode"] == 200
    response_headers = response["headers"]
    assert response_headers["Content-Type"] == content_types.TEXT_HTML
    assert response["body"] == "foo"
def test_api_gateway_v1():
    """The handler sees both the typed event and the Lambda context passed in."""
    # GIVEN a resolver configured for the APIGatewayProxyEvent proxy type
    resolver = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)

    @resolver.get("/my/path")
    def json_route() -> Response:
        # The handler checks the event type and that the context is the {} given below
        assert isinstance(resolver.current_event, APIGatewayProxyEvent)
        assert resolver.lambda_context == {}
        serialized = json.dumps({"foo": "value"})
        return Response(200, content_types.APPLICATION_JSON, serialized)

    # WHEN resolving the fixture event with an empty dict as context
    response = resolver(LOAD_GW_EVENT, {})

    # THEN the call succeeds with a JSON content type
    assert response["statusCode"] == 200
    response_headers = response["headers"]
    assert response_headers["Content-Type"] == content_types.APPLICATION_JSON
def test_rest_api():
    """A dict returned by a handler is serialized into a compact JSON 200 response."""
    # GIVEN a route that returns a dict containing a Decimal
    resolver = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)
    payload = {"foo": "value", "second": Decimal("100.01")}

    @resolver.get("/my/path")
    def dict_route() -> Dict:
        return payload

    # WHEN resolving the fixture event
    response = resolver(LOAD_GW_EVENT, {})

    # THEN the response is a 200 with a JSON content type
    assert response["statusCode"] == 200
    response_headers = response["headers"]
    assert response_headers["Content-Type"] == content_types.APPLICATION_JSON

    # AND the body is the payload dumped compactly with the project Encoder
    compact_json = json.dumps(payload, separators=(",", ":"), indent=None, cls=Encoder)
    assert response["body"] == compact_json
def test_debug_unhandled_exceptions_debug_off():
    """With debug off, an unhandled handler exception propagates to the caller."""
    # GIVEN a resolver with debug explicitly disabled
    resolver = ApiGatewayResolver(debug=False)
    assert not resolver._debug

    # AND a route that always raises
    @resolver.get("/raises-error")
    def always_fails():
        raise RuntimeError("Foo")

    # WHEN invoking that route
    # THEN the RuntimeError escapes the resolver
    failing_request = {"path": "/raises-error", "httpMethod": "GET"}
    with pytest.raises(RuntimeError) as exc_info:
        resolver(failing_request, None)

    # AND it still carries the original arguments
    assert exc_info.value.args == ("Foo",)
def test_alb_event():
    """ALB proxy events are exposed to handlers as ALBEvent."""
    # GIVEN a resolver configured for the ALBEvent proxy type
    resolver = ApiGatewayResolver(proxy_type=ProxyEventType.ALBEvent)

    @resolver.get("/lambda")
    def alb_route():
        # The handler checks the event type and that the context is the {} given below
        assert isinstance(resolver.current_event, ALBEvent)
        assert resolver.lambda_context == {}
        return Response(200, content_types.TEXT_HTML, "foo")

    # WHEN resolving the ALB fixture event
    alb_event = load_event("albEvent.json")
    response = resolver(alb_event, {})

    # THEN the handler's response is returned
    assert response["statusCode"] == 200
    response_headers = response["headers"]
    assert response_headers["Content-Type"] == content_types.TEXT_HTML
    assert response["body"] == "foo"
def test_api_gateway_v2():
    """HTTP API v2 proxy events are exposed to handlers as APIGatewayProxyEventV2."""
    # GIVEN a resolver configured for the APIGatewayProxyEventV2 proxy type
    resolver = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEventV2)

    @resolver.post("/my/path")
    def echo_username() -> Response:
        current = resolver.current_event
        assert isinstance(current, APIGatewayProxyEventV2)
        username = current.json_body["username"]
        return Response(200, content_types.TEXT_PLAIN, username)

    # WHEN resolving the v2 fixture event
    v2_event = load_event("apiGatewayProxyV2Event.json")
    response = resolver(v2_event, {})

    # THEN the username taken from the JSON body is echoed back as plain text
    assert response["statusCode"] == 200
    response_headers = response["headers"]
    assert response_headers["Content-Type"] == content_types.TEXT_PLAIN
    assert response["body"] == "tom"
def test_cache_control_200():
    """A 200 response carries the route's configured Cache-Control value."""
    # GIVEN a route registered with cache_control
    resolver = ApiGatewayResolver()

    @resolver.get("/success", cache_control="max-age=600")
    def cached_ok() -> Response:
        return Response(200, content_types.TEXT_HTML, "has 200 response")

    # WHEN the route is resolved
    # AND it answers with status 200
    request = {"path": "/success", "httpMethod": "GET"}
    response = resolver.resolve(request, None)

    # THEN the configured Cache-Control is passed through
    response_headers = response["headers"]
    assert response_headers["Content-Type"] == content_types.TEXT_HTML
    assert response_headers["Cache-Control"] == "max-age=600"
def test_cache_control_non_200():
    """A non-200 response gets "no-cache" in place of the configured Cache-Control."""
    # GIVEN a route registered with cache_control
    resolver = ApiGatewayResolver()

    @resolver.delete("/fails", cache_control="max-age=600")
    def cached_unavailable() -> Response:
        return Response(503, content_types.TEXT_HTML, "has 503 response")

    # WHEN the route is resolved
    # AND it answers with status 503
    request = {"path": "/fails", "httpMethod": "DELETE"}
    response = resolver.resolve(request, None)

    # THEN Cache-Control is "no-cache" rather than the configured value
    response_headers = response["headers"]
    assert response_headers["Content-Type"] == content_types.TEXT_HTML
    assert response_headers["Cache-Control"] == "no-cache"
def test_debug_unhandled_exceptions_debug_on():
    """With debug on, an unhandled exception becomes a 500 containing the traceback."""
    # GIVEN a resolver with debug enabled
    resolver = ApiGatewayResolver(debug=True)
    assert resolver._debug

    # AND a route that always raises
    @resolver.get("/raises-error")
    def always_fails():
        raise RuntimeError("Foo")

    # WHEN invoking that route
    failing_request = {"path": "/raises-error", "httpMethod": "GET"}
    response = resolver(failing_request, None)

    # THEN the response is a 500
    assert response["statusCode"] == 500
    # AND the body includes the traceback text
    assert "Traceback (most recent call last)" in response["body"]
    # AND it is served as plain text
    response_headers = response["headers"]
    assert response_headers["Content-Type"] == content_types.TEXT_PLAIN
def test_handling_response_type():
    """A Response object controls status, body and headers, including Content-Type."""
    # GIVEN a route returning a Response whose headers carry their own Content-Type
    resolver = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)

    @resolver.get("/my/path")
    def not_found_route() -> Response:
        explicit_headers = {"Content-Type": "header-content-type-wins", "custom": "value"}
        return Response(
            status_code=404,
            content_type="used-if-not-set-in-header",
            body="Not found",
            headers=explicit_headers,
        )

    # WHEN resolving the fixture event
    response = resolver(LOAD_GW_EVENT, {})

    # THEN status and body come straight from the Response
    assert response["statusCode"] == 404
    assert response["body"] == "Not found"

    # AND the Content-Type from the headers overrides the content_type argument
    # AND the extra header is passed through
    response_headers = response["headers"]
    assert response_headers["Content-Type"] == "header-content-type-wins"
    assert response_headers["custom"] == "value"
import json
import os

import boto3
from aws_lambda_powertools import Logger, Metrics, Tracer
from aws_lambda_powertools.logging import correlation_paths
from aws_lambda_powertools.utilities.typing import LambdaContext
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

# https://awslabs.github.io/aws-lambda-powertools-python/#features
# Module-level Powertools utilities, each created with its default arguments.
tracer = Tracer()
logger = Logger()
metrics = Metrics()
# Resolver on which the @app.get(...) decorators below register their handlers.
app = ApiGatewayResolver()

# Global variables are reused across execution contexts (if available)
# session = boto3.Session()

@app.get("/hello")
def hello():
    """Greet the value of the ``name`` query string parameter, or "universe" when it is absent."""
    current_event = app.current_event
    who = current_event.get_query_string_value(name="name", default_value="universe")
    greeting = f"hello {who}"
    return {"message": greeting}


@app.get("/hello/<name>")
def hello_you(name):
    """Greet the ``name`` captured from the request path."""
    # Other request data, if needed:
    # query_strings_as_dict = app.current_event.query_string_parameters
    # json_payload = app.current_event.json_body
    greeting = f"hello {name}"
    return {"message": greeting}

@metrics.log_metrics(capture_cold_start_metric=True)
def test_service_error_responses():
    """Each service error raised by a handler becomes a JSON error response.

    SCENARIO: routes raising the different service error types are registered
    one at a time and invoked right away; status code, content type and JSON
    body are checked for each.
    """
    # NOTE(review): the log_metrics decorator above is kept unchanged; it looks
    # like it belongs on a Lambda handler taking (event, context) rather than on
    # a zero-argument test -- confirm it is intended here.
    app = ApiGatewayResolver(cors=CORSConfig())

    def call_get(path):
        # Invoke the resolver with a minimal GET event for the given path
        return app({"path": path, "httpMethod": "GET"}, None)

    def assert_json_error(result, status_code, message):
        # Shared checks: status, JSON content type, compact JSON error body
        assert result["statusCode"] == status_code
        assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
        expected = {"statusCode": status_code, "message": message}
        assert result["body"] == json.dumps(expected, separators=(",", ":"))

    # GIVEN a route raising BadRequestError
    @app.get(rule="/bad-request-error", cors=False)
    def bad_request_error():
        raise BadRequestError("Missing required parameter")

    # WHEN calling /bad-request-error
    # THEN a 400 with the given message is returned
    assert_json_error(call_get("/bad-request-error"), 400, "Missing required parameter")

    # GIVEN a route raising UnauthorizedError
    @app.get(rule="/unauthorized-error", cors=False)
    def unauthorized_error():
        raise UnauthorizedError("Unauthorized")

    # WHEN calling /unauthorized-error
    # THEN a 401 with the given message is returned
    assert_json_error(call_get("/unauthorized-error"), 401, "Unauthorized")

    # GIVEN a route raising NotFoundError without arguments
    @app.get(rule="/not-found-error", cors=False)
    def not_found_error():
        raise NotFoundError

    # WHEN calling /not-found-error
    # THEN a 404 with the message "Not found" is returned
    assert_json_error(call_get("/not-found-error"), 404, "Not found")

    # GIVEN a route raising InternalServerError
    @app.get(rule="/internal-server-error", cors=False)
    def internal_server_error():
        raise InternalServerError("Internal server error")

    # WHEN calling /internal-server-error
    # THEN a 500 with the given message is returned
    assert_json_error(call_get("/internal-server-error"), 500, "Internal server error")

    # GIVEN a CORS-enabled route raising ServiceError with a custom status code
    @app.get(rule="/service-error", cors=True)
    def service_error():
        raise ServiceError(502, "Something went wrong!")

    # WHEN calling /service-error
    service_result = call_get("/service-error")
    # THEN a 502 with the given message is returned
    # AND the CORS allow-origin header is present
    assert_json_error(service_result, 502, "Something went wrong!")
    assert "Access-Control-Allow-Origin" in service_result["headers"]