Example #1
0
def test_cors_strategy_raises_for_duplicate_policy_name():
    """Adding a second policy under an already-used name must be rejected."""
    strategy = CORSStrategy(CORSPolicy(), Router())
    strategy.add_policy("a", CORSPolicy())

    # Same name again -> configuration error.
    with pytest.raises(CORSConfigurationError):
        strategy.add_policy("a", CORSPolicy())
Example #2
0
def test_cors_policy_raises_for_negative_max_age():
    """A negative max_age is refused by both the constructor and the setter."""
    # Constructor path.
    with pytest.raises(ValueError):
        CORSPolicy(max_age=-1)

    # Property-setter path.
    default_policy = CORSPolicy()
    with pytest.raises(ValueError):
        default_policy.max_age = -5
Example #3
0
def test_cors_strategy_raises_for_missing_policy_name():
    """Empty-string and None policy names are both configuration errors."""
    strategy = CORSStrategy(CORSPolicy(), Router())

    for missing_name in ("", None):
        with pytest.raises(CORSConfigurationError):
            strategy.add_policy(missing_name, CORSPolicy())  # type: ignore
Example #4
0
def test_cors_policy_setters_force_case():
    """
    Setters normalize case: methods become upper-case, headers and origins
    become lower-case.
    """
    policy = CORSPolicy()

    # (attribute, assigned value, expected normalized set), checked in order.
    cases = (
        ("allow_methods", ["get", "delete"], {"GET", "DELETE"}),
        ("allow_headers", ["X-Foo", "Authorization"], {"x-foo", "authorization"}),
        (
            "allow_origins",
            ["http://Example.com", "https://Bezkitu.ORG"],
            {"http://example.com", "https://bezkitu.org"},
        ),
    )

    for attribute, assigned, expected in cases:
        setattr(policy, attribute, assigned)
        assert getattr(policy, attribute) == expected
Example #5
0
def test_cors_policy():
    """Constructor arguments are stored as sets; header names are lower-cased."""
    methods = ["GET", "POST", "DELETE"]
    headers = ["Authorization"]
    origins = ["http://localhost:44555"]

    policy = CORSPolicy(
        allow_methods=methods,
        allow_headers=headers,
        allow_origins=origins,
    )

    assert policy.allow_methods == {"GET", "POST", "DELETE"}
    assert policy.allow_headers == {"authorization"}
    assert policy.allow_origins == {"http://localhost:44555"}
Example #6
0
def test_cors_policy_allow_all_methods():
    """
    Each allow_any_* helper switches its collection from empty to the
    wildcard {"*"}. This covers headers, methods, and origins, not only
    methods.
    """
    policy = CORSPolicy()

    for attribute, enabler in (
        ("allow_headers", "allow_any_header"),
        ("allow_methods", "allow_any_method"),
        ("allow_origins", "allow_any_origin"),
    ):
        assert getattr(policy, attribute) == set()
        getattr(policy, enabler)()
        assert getattr(policy, attribute) == {"*"}
Example #7
0
    def add_cors_policy(
        self,
        policy_name,
        *,
        allow_methods: Union[None, str, Iterable[str]] = None,
        allow_headers: Union[None, str, Iterable[str]] = None,
        allow_origins: Union[None, str, Iterable[str]] = None,
        allow_credentials: bool = False,
        max_age: int = 5,
        expose_headers: Union[None, str, Iterable[str]] = None,
    ) -> None:
        """
        Registers a named set of CORS rules that can later be applied to
        specific request handlers by name.

        If CORS has not been enabled yet, `use_cors()` is called first with no
        arguments. The new policy is then added to the application's
        `CORSStrategy`. That strategy instance can be used as a function
        decorator to associate the policy with a request handler:

            @app.cors("example")
            @app.route("/")
            async def foo():
                ...

        Raises ApplicationAlreadyStartedCORSError if the application has
        already started.
        """
        if self.started:
            raise ApplicationAlreadyStartedCORSError()

        # Lazily enable CORS so named policies have a strategy to live in.
        if not self._cors_strategy:
            self.use_cors()

        named_policy = CORSPolicy(
            allow_methods=allow_methods,
            allow_headers=allow_headers,
            allow_origins=allow_origins,
            allow_credentials=allow_credentials,
            max_age=max_age,
            expose_headers=expose_headers,
        )
        self._cors_strategy.add_policy(policy_name, named_policy)
Example #8
0
    def use_cors(
        self,
        *,
        allow_methods: Union[None, str, Iterable[str]] = None,
        allow_headers: Union[None, str, Iterable[str]] = None,
        allow_origins: Union[None, str, Iterable[str]] = None,
        allow_credentials: bool = False,
        max_age: int = 5,
        expose_headers: Union[None, str, Iterable[str]] = None,
    ) -> CORSStrategy:
        """
        Enables CORS for the application and returns the resulting
        `CORSStrategy`.

        The keyword arguments define the default CORS rules applied to all
        request handlers. This also registers a catch-all OPTIONS route on
        the application's router.

        Raises ApplicationAlreadyStartedCORSError if the application has
        already started.
        """
        if self.started:
            raise ApplicationAlreadyStartedCORSError()

        default_policy = CORSPolicy(
            allow_methods=allow_methods,
            allow_headers=allow_headers,
            allow_origins=allow_origins,
            allow_credentials=allow_credentials,
            max_age=max_age,
            expose_headers=expose_headers,
        )
        self._cors_strategy = CORSStrategy(default_policy, self.router)

        # A catch-all OPTIONS route is needed for preflight requests to be
        # handled at all. The route's handler is a no-op. Preflight requests
        # are answered by the CORS middleware, which stops the middleware
        # chain so that extra logic (e.g. authentication) does not run for
        # them. Because of this, user-defined catch-all OPTIONS handlers are
        # not supported when the built-in CORS handling is used.
        register_catch_all_options = self.router.options("*")

        async def options_handler(request):
            return Response(404)

        register_catch_all_options(options_handler)

        return self._cors_strategy
Example #9
0
def test_cors_policy_setters_strings():
    """
    String values assigned to the setters are split on spaces, commas, or
    semicolons and case-normalized. Assigning None yields an empty frozenset.
    """
    policy = CORSPolicy()

    policy.allow_methods = "get delete"
    assert policy.allow_methods == {"GET", "DELETE"}

    expected_methods = {"GET", "POST", "PATCH"}
    for methods in (
        "GET POST PATCH",
        "GET, POST, PATCH",
        "GET,POST,PATCH",
        "GET;POST;PATCH",
    ):
        policy.allow_methods = methods
        assert policy.allow_methods == expected_methods

    expected_headers = {"x-foo", "authorization"}
    for headers in (
        "X-Foo Authorization",
        "X-Foo, Authorization",
        "X-Foo,Authorization",
    ):
        policy.allow_headers = headers
        assert policy.allow_headers == expected_headers

    policy.allow_origins = "http://Example.com https://Bezkitu.ORG"
    assert policy.allow_origins == {
        "http://example.com", "https://bezkitu.org"
    }

    # Resetting each collection with None leaves it empty.
    for attribute in ("allow_headers", "allow_methods", "allow_origins"):
        setattr(policy, attribute, None)
        assert getattr(policy, attribute) == frozenset()