def get_app():
    """Return the process-wide application, building it on the first call.

    The instance is cached in the module-level ``_app`` global. Once it is
    set, later calls return that same object without re-registering routers,
    mounts or middleware.
    """
    global _app
    if _app:
        return _app

    cors = Middleware(
        CORSMiddleware,
        allow_origins=ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    _app = CustomApp(
        title="FastAPI backend template",
        version="0.0.4",
        description="",
        exception_handlers=None,
        middleware=(Middleware(PrometheusMiddleware), cors),
    )

    for included in (service_route, router):
        _app.include_router(included)
    for mount_args in MOUNTS:
        # Each MOUNTS entry is unpacked positionally into ``mount``.
        _app.mount(*mount_args)

    # Sentry is only wired in when a DSN is configured.
    if settings.SENTRY_DSN:
        sentry_sdk.init(dsn=settings.SENTRY_DSN)
        _app.add_middleware(SentryAsgiMiddleware)

    db.init_app(_app)
    return _app
def build_middleware_stack(self) -> ASGIApp:
    """Wrap the router in server-error, user and exception middleware.

    A handler registered for ``500`` or ``Exception`` is routed to
    ``ServerErrorMiddleware``; if several match, the last one wins. Every
    other handler goes to ``ExceptionMiddleware``. Middleware entries use the
    ``(cls, options, enabled)`` form. Entries whose ``enabled`` flag is false
    are skipped.

    Returns:
        The outermost ASGI app: ServerErrorMiddleware -> user middleware ->
        ExceptionMiddleware -> router.
    """
    debug = self.debug
    catch_all = (500, Exception)

    error_handler = None
    exception_handlers = {}
    for key, handler in self.exception_handlers.items():
        if key not in catch_all:
            exception_handlers[key] = handler
            continue
        error_handler = handler

    outermost = Middleware(
        ServerErrorMiddleware,
        options={"handler": error_handler, "debug": debug},
    )
    innermost = Middleware(
        ExceptionMiddleware,
        options={"handlers": exception_handlers, "debug": debug},
    )
    stack = [outermost] + self.user_middleware + [innermost]

    # Wrap from the inside out so the first list entry ends up outermost.
    wrapped = self.router
    for middleware_cls, middleware_options, is_enabled in stack[::-1]:
        if not is_enabled:
            continue
        wrapped = middleware_cls(app=wrapped, **middleware_options)
    return wrapped
def _make_csrf_app(**options):
    """Build a small Starlette app protected by ``CSRFProtectMiddleware``.

    Keyword ``options`` are forwarded to ``CSRFProtectMiddleware``. The app
    exposes ``/`` (GET/POST), ``/token`` and ``/new-token``.

    Returns:
        ``(app, client)`` where ``client`` is a ``TestClient`` bound to ``app``.
    """
    session_layer = Middleware(SessionMiddleware, secret_key='xxx')
    csrf_layer = Middleware(CSRFProtectMiddleware, csrf_secret='yyy', **options)
    app = Starlette(middleware=[session_layer, csrf_layer])

    @app.route('/', methods=['GET', 'POST'])
    async def index(request):
        return PlainTextResponse()

    @app.route('/token', methods=['GET'])
    async def token(request):
        return PlainTextResponse(csrf_token(request))

    @app.route('/new-token', methods=['GET'])
    async def new_token(request):
        # Drop any token cached in the session before asking for one
        # (presumably forces csrf_token() to issue a fresh value — confirm).
        request.session.pop('csrf_token', None)
        return PlainTextResponse(csrf_token(request))

    return app, TestClient(app)
def create_app():
    """Create and configure the Starlette application.

    Project modules are imported at call time rather than at module import
    time (presumably to avoid import cycles — confirm).

    Returns:
        A ``Starlette`` app with the project routes, an ``HTTPException``
        handler, GZip and permissive CORS middleware, and
        ``start_services`` / ``stop_services`` lifecycle hooks.
    """
    from .config.application import DEBUG
    from .views.urls import routes
    # NOTE(review): WebScraper is not referenced below; kept in case the
    # import itself is relied on for side effects.
    from .clients.web_scraper import WebScraper
    from .utils.exception_handler import http_exception_handler

    handlers = {
        HTTPException: http_exception_handler,
        # Exception: http_exception
    }
    stack = [
        Middleware(GZipMiddleware),
        # Accept any origin, method and header.
        Middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_methods=["*"],
            allow_headers=["*"],
        ),
    ]
    return Starlette(
        routes=routes,
        debug=DEBUG,
        exception_handlers=handlers,
        middleware=stack,
        on_startup=[start_services],
        on_shutdown=[stop_services],
    )
def setup_app(settings: Settings) -> FastAPI:
    """Build the FastAPI application from ``settings``.

    Adds CORS middleware for the origins returned by ``add_origins()``.
    When ``settings.sentry_dsn`` is set, it also initialises Sentry and adds
    its ASGI middleware. The settings object and a customised logger are
    passed to ``FastAPI`` as extra keyword arguments under ``SETTINGS_KEY`` /
    ``LOGGER_KEY``.

    Returns:
        The app returned by ``register_routers``.
    """
    cors = Middleware(
        CORSMiddleware,
        allow_origins=add_origins(),
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    middlewares = [cors]

    if settings.sentry_dsn:
        sentry_sdk.init(dsn=settings.sentry_dsn, release=__version__)
        middlewares.append(Middleware(SentryAsgiMiddleware))

    extras = {
        SETTINGS_KEY: settings,
        LOGGER_KEY: CastomizeLogger.make_logger(),
    }
    app = FastAPI(
        title=__module_name__,
        version=__version__,
        description="Grabbers for nothing",
        middleware=middlewares,
        **extras,  # type: ignore
    )
    return register_routers(app)
def build_middleware_stack(self) -> ASGIApp:
    """Wrap the router in server-error, user and exception middleware.

    Handlers keyed by ``500`` or ``Exception`` are routed to
    ``ServerErrorMiddleware``; if several match, the last one wins. All other
    handlers go to ``ExceptionMiddleware``.

    Returns:
        The outermost ASGI app: ServerErrorMiddleware -> user middleware ->
        ExceptionMiddleware -> router.
    """
    debug = self.debug
    catch_all = (500, Exception)

    exception_handlers = {
        key: handler
        for key, handler in self.exception_handlers.items()
        if key not in catch_all
    }
    error_handler = None
    for key, handler in self.exception_handlers.items():
        if key in catch_all:
            error_handler = handler

    stack = (
        [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
        + self.user_middleware
        + [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)]
    )

    # Build inside-out so the first entry in ``stack`` is the outermost layer.
    wrapped = self.router
    for middleware_cls, middleware_options in stack[::-1]:
        wrapped = middleware_cls(app=wrapped, **middleware_options)
    return wrapped
def build_app(self) -> ASGIApp:
    """Wrap ``self.app`` in server-error, user and exception middleware.

    Handlers keyed by ``500`` or ``Exception`` go to ``ServerErrorMiddleware``;
    if several match, the last one wins. All others go to
    ``ExceptionMiddleware``.

    Returns:
        The outermost ASGI app: ServerErrorMiddleware -> user middlewares ->
        ExceptionMiddleware -> ``self.app``.
    """
    error_handler = None
    exception_handlers = {}
    for key, handler in self.exception_handlers.items():
        if key in (500, Exception):
            error_handler = handler
        else:
            exception_handlers[key] = handler

    layers = [
        Middleware(ServerErrorMiddleware, handler=error_handler, debug=self.debug),
        *self.user_middlewares,
        # NOTE(review): unlike ServerErrorMiddleware, ``debug`` is not
        # forwarded here — confirm this is intentional.
        Middleware(ExceptionMiddleware, handlers=exception_handlers),
    ]

    wrapped = self.app
    for layer_cls, layer_kwargs in reversed(layers):
        wrapped = layer_cls(app=wrapped, **layer_kwargs)
    return wrapped
def populate_middlewares(self, auth_token_verify_user_callback=None, cors=True, debug=False) -> typing.List[Middleware]:
    """Assemble the optional middleware list for the application.

    Args:
        auth_token_verify_user_callback: when truthy, a token backend class is
            built from it via ``build_token_backend`` and installed with
            ``AuthenticationMiddleware`` (errors handled by ``on_auth_error``).
        cors: when truthy, add a CORS middleware allowing any origin, method
            and header.
        debug: when truthy, Sentry is never added. Otherwise Sentry is added
            only if ``self.sentry_dsn`` is set.

    Returns:
        The middleware list, in the order above.
    """
    middlewares = []

    if auth_token_verify_user_callback:
        backend_cls = build_token_backend(auth_token_verify_user_callback)
        middlewares.append(
            Middleware(
                AuthenticationMiddleware,
                backend=backend_cls(),
                on_error=on_auth_error,
            ))

    if cors:
        middlewares.append(
            Middleware(
                CORSMiddleware,
                allow_methods=["*"],
                allow_origins=["*"],
                allow_headers=["*"],
            ))

    # ``self.sentry_dsn`` is only consulted outside debug mode.
    if debug or not self.sentry_dsn:
        return middlewares

    # Imported lazily so sentry_sdk is only needed when Sentry is enabled.
    from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
    middlewares.append(Middleware(SentryAsgiMiddleware))
    print("Adding Sentry middleware to application")
    return middlewares
def middleware(prod):
    """Return the application's middleware stack, outermost first.

    Args:
        prod: when truthy, ``PrometheusMiddleware`` is appended as the
            innermost layer.
    """
    stack = [
        Middleware(HTTPSRedirectMiddleware),
        Middleware(SessionMiddleware, **SessionMiddlewareArgs),
        Middleware(AuthenticationMiddleware, backend=OdooBasicAuthBackendAsync()),
        Middleware(OdooEnvironmentMiddlewareAsync),
    ]
    if prod:
        stack.append(Middleware(PrometheusMiddleware))
    return stack
def create_middleware():
    """Return the session middleware pair, outermost first.

    ``SessionMiddleware`` is signed with ``settings.secret_key`` and wraps
    ``SessionIDMiddleware`` (presumably because the latter reads the session
    — confirm).
    """
    return [
        Middleware(SessionMiddleware, secret_key=settings.secret_key),
        Middleware(SessionIDMiddleware),
    ]
def init_with_instana(wrapped, instance, args, kwargs):
    """Add ``InstanaASGIMiddleware`` to a constructor's ``middleware`` kwarg.

    The ``(wrapped, instance, args, kwargs)`` signature follows the wrapt
    wrapper convention. It presumably wraps ``Starlette.__init__`` — confirm
    at the registration site.

    Behaviour:
      * ``middleware`` missing or ``None``: a one-element list with the
        Instana middleware is passed instead.
      * ``middleware`` is a list or tuple: a new list is passed with the
        Instana middleware appended last (innermost). The caller's sequence
        is never mutated, so reusing it for several apps does not accumulate
        duplicate Instana entries.
      * Anything else, or a value passed positionally, is left untouched.

    Returns:
        Whatever ``wrapped(*args, **kwargs)`` returns.
    """
    middleware = kwargs.get('middleware')
    if middleware is None:
        kwargs['middleware'] = [Middleware(InstanaASGIMiddleware)]
    elif isinstance(middleware, (list, tuple)):
        kwargs['middleware'] = [*middleware, Middleware(InstanaASGIMiddleware)]
    return wrapped(*args, **kwargs)
def _register_middleware(self, config: AppConfig) -> Sequence[Middleware]:
    """Return the application's middleware list built from ``config``.

    The list is ordered outermost first: trusted-host check, CORS, request-ID
    context, exception handling, then the SQLAlchemy session bound to
    ``self._database_adapter``.
    """
    host_guard = Middleware(TrustedHostMiddleware, allowed_hosts=config.ALLOWED_HOSTS)
    cors = Middleware(CORSMiddleware, allow_origins=config.ALLOW_ORIGINS)
    request_context = Middleware(ContextMiddleware.with_plugins(RequestIdPlugin))
    error_handling = Middleware(ExceptHandlerMiddleware)
    db_session = Middleware(SQLAlchemySessionMiddleware, alchemy=self._database_adapter)
    # TODO: add HTTPSRedirectMiddleware and SessionMiddleware (if a secret_key
    # is needed, it has already been placed in the reference config file).
    middleware: Sequence[Middleware] = [
        host_guard,
        cors,
        request_context,
        error_handling,
        db_session,
    ]
    return middleware
def create_application(
    database_url: str,
    user_pool_emulator_url_base: str,
    debug: bool = False,
    max_pool_workers: int = 10,
    region: str = "mars-east-1",
    prepended_routes: typing.Iterable[BaseRoute] = (),
):
    """Create and configure the Starlette application.

    Side effects: reconfigures root logging (``force=True``), may replace the
    shared ``executor.executor`` thread pool, and binds ``session_factory``
    to a new SQLAlchemy engine.

    Args:
        database_url: URL passed to ``sa.create_engine``; the engine is bound
            to ``session_factory``.
        user_pool_emulator_url_base: stored as
            ``app.state.user_pool_emulator_url_base``.
        debug: enables Starlette debug mode and DEBUG-level logging
            (INFO otherwise).
        max_pool_workers: size of the ``ThreadPoolExecutor`` installed as
            ``executor.executor``. When <= 0, no pool is created here.
        region: stored as ``app.state.region``.
        prepended_routes: extra routes placed before the built-in
            static/admin/index mounts. Only iterated, never mutated. The
            default is an immutable empty tuple rather than a shared list.

    Returns:
        The configured ``Starlette`` application.
    """
    from .views import admin as admin_views
    from .views import index as index_views

    logging.basicConfig(
        level=logging.DEBUG if debug else logging.INFO,
        force=True,
    )
    if max_pool_workers > 0:
        executor.executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=max_pool_workers)
    session_factory.configure(bind=sa.create_engine(database_url))

    routes: typing.List[BaseRoute] = list(prepended_routes)
    routes += [
        Mount(
            "/static",
            app=StaticFiles(directory=str(basedir / "static")),
            name="static",
        ),
        Mount(
            "/admin",
            name="admin",
            app=admin_views.routes,
        ),
        # Catch-all mount; kept last so the more specific mounts match first.
        Mount("/", index_views.routes),
    ]

    def _shutdown_executor():
        # executor.executor may be None if no pool was ever configured.
        if executor.executor is not None:
            executor.executor.shutdown(wait=True)

    app = Starlette(
        debug=debug,
        routes=routes,
        middleware=[
            Middleware(RequestTimeMiddleware),
            Middleware(SQLAlchemyMiddleware),
            Middleware(TemplateShortcutMiddleware),
        ],
        on_shutdown=[_shutdown_executor],
    )
    app.state.user_pool_emulator_url_base = user_pool_emulator_url_base
    app.state.region = region
    app.state.templates = Jinja2Templates(directory=str(basedir / "templates"))
    app.state.uuidgen = lambda: str(uuid.uuid4())
    return app
def create_app(cfg: Config, loop: asyncio.AbstractEventLoop) -> Starlette:
    """Build the Starlette app and start the crawler's polling on ``loop``.

    A single ``Crawler`` is created from ``cfg`` and injected into requests
    via ``CrawlerInjectMiddleware``. CORS allows any origin.

    Args:
        cfg: provides the Telegram session/API credentials, proxy and debug flag.
        loop: event loop on which ``crawler.start_poll()`` is scheduled via
            ``run_coroutine_threadsafe`` (presumably running in another
            thread — confirm).
    """
    crawler = Crawler(cfg.TG_SESSION, cfg.API_ID, cfg.API_HASH, cfg.PROXY)
    middlewares = [
        Middleware(CORSMiddleware, allow_origins=["*"]),
        Middleware(CrawlerInjectMiddleware, crawler=crawler),
    ]

    # NOTE(review): the returned future is discarded, so an exception raised
    # by start_poll() is never observed here.
    asyncio.run_coroutine_threadsafe(crawler.start_poll(), loop)

    return Starlette(
        debug=cfg.DEBUG,
        routes=routes(),
        middleware=middlewares,
        exception_handlers={HTTPException: _exception_handler},
    )
def build_middleware_stack(self) -> ASGIApp:
    """Build the ASGI middleware chain around the router.

    Duplicates/overrides Starlette's implementation to add
    ``AsyncExitStackMiddleware`` inside ``ExceptionMiddleware``, which in
    turn sits inside the custom user middlewares. The resulting order,
    outermost first, is: ServerErrorMiddleware -> user middleware ->
    ExceptionMiddleware -> AsyncExitStackMiddleware -> router.
    """
    debug = self.debug

    # Handlers for 500/Exception feed ServerErrorMiddleware (last one wins);
    # everything else feeds ExceptionMiddleware.
    error_handler = None
    exception_handlers = {}
    for key, handler in self.exception_handlers.items():
        if key not in (500, Exception):
            exception_handlers[key] = handler
            continue
        error_handler = handler

    outer_layers = [
        Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug),
    ]
    inner_layers = [
        Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug),
        # Add FastAPI-specific AsyncExitStackMiddleware for dependencies with
        # contextvars.
        # This needs to happen after user middlewares because those create a
        # new contextvars context copy by using a new AnyIO task group.
        # The initial part of dependencies with yield is executed in the
        # FastAPI code, inside all the middlewares, but the teardown part
        # (after yield) is executed in the AsyncExitStack in this middleware.
        # If the AsyncExitStack lived outside of the custom middlewares and
        # contextvars were set in a dependency with yield in that internal
        # contextvars context, the values would not be available in the
        # outside context of the AsyncExitStack.
        # By putting the middleware and the AsyncExitStack here, inside all
        # user middlewares, the code before and after yield in dependencies
        # with yield is executed in the same contextvars context, so all values
        # set in contextvars before yield are still available after yield, as
        # would be expected.
        # Additionally, by having this AsyncExitStack here, after the
        # ExceptionMiddleware, dependencies can now catch handled exceptions,
        # e.g. HTTPException, to customize the teardown code (e.g. DB session
        # rollback).
        Middleware(AsyncExitStackMiddleware),
    ]
    middleware = outer_layers + self.user_middleware + inner_layers

    # Wrap inside-out so the first list entry becomes the outermost layer.
    wrapped = self.router
    for layer_cls, layer_options in reversed(middleware):
        wrapped = layer_cls(app=wrapped, **layer_options)
    return wrapped
def asgi_app(applications, cdn=True, static_dir=None, debug=False, allowed_origins=None, check_origin=None, **starlette_settings):
    """Build a Starlette ASGI app that serves the given PyWebIO applications.

    Args:
        applications: forwarded to ``webio_routes``.
        cdn: validated with ``cdn_validation``. When it resolves to ``False``,
            static assets are served from the locally mounted
            ``/pywebio_static`` route instead.
        static_dir: if given, its files are mounted at ``/static``.
        debug: Starlette debug mode; also stored on ``Session.debug``. The
            ``PYWEBIO_DEBUG`` environment variable, when set, overrides it.
            Values ``""``, ``"0"``, ``"false"``, ``"no"`` and ``"off"``
            (case-insensitive) disable debug; any other value enables it.
        allowed_origins, check_origin: forwarded to ``webio_routes``.
        **starlette_settings: extra keyword arguments for ``Starlette``.

    Returns:
        The configured ``Starlette`` application.
    """
    env_debug = os.environ.get("PYWEBIO_DEBUG")
    if env_debug is not None:
        # Environment values are strings: without parsing, "0"/"false" would
        # be truthy and silently enable debug tracebacks.
        debug = env_debug.strip().lower() not in ("", "0", "false", "no", "off")
    Session.debug = debug

    cdn = cdn_validation(cdn, "warn")
    if cdn is False:
        # Point the frontend at the local static mount added below.
        cdn = "pywebio_static"

    routes = webio_routes(
        applications,
        cdn=cdn,
        allowed_origins=allowed_origins,
        check_origin=check_origin,
    )
    if static_dir:
        routes.append(
            Mount("/static", app=StaticFiles(directory=static_dir), name="static"))
    routes.append(
        Mount(
            "/pywebio_static",
            app=StaticFiles(directory=STATIC_PATH),
            name="pywebio_static",
        ))

    middleware = [Middleware(HeaderMiddleware)]
    return Starlette(routes=routes, middleware=middleware, debug=debug, **starlette_settings)
class ExampleApp(App):
    """Example ``App`` subclass that records whether ``setup`` was called."""

    routes = [testroute]
    middlewares = [Middleware(TestMiddleware)]
    # Class-level default; setup() sets the instance attribute to True.
    was_setup = False

    def setup(self):
        """Mark this instance as having been set up."""
        self.was_setup = True
def starlette_application(unused_udp_port):
    """Build a debug Starlette app instrumented with ``StatsDMiddleware``.

    The ``aiodogstatsd`` client targets ``unused_udp_port`` with a constant
    ``whoami`` tag. It is connected and closed through the app's startup and
    shutdown hooks. Routes cover 200, path-parameter, 400, 500 and 401
    responses.
    """
    async def handler_hello(request):
        return JSONResponse({"hello": "aiodogstatsd"})

    async def handler_hello_variable(request):
        name = request.path_params["name"]
        return JSONResponse({"hello": name})

    async def handler_bad_request(request):
        return JSONResponse({"hello": "bad"}, status_code=HTTPStatus.BAD_REQUEST)

    async def handler_internal_server_error(request):
        # Deliberately unhandled to exercise the server-error path.
        raise NotImplementedError()

    async def handler_unauthorized(request):
        raise HTTPException(HTTPStatus.UNAUTHORIZED)

    statsd = aiodogstatsd.Client(
        host="0.0.0.0",
        port=unused_udp_port,
        constant_tags={"whoami": "batman"},
    )
    routes = [
        Route("/hello", handler_hello),
        Route("/hello/{name}", handler_hello_variable),
        Route("/bad_request", handler_bad_request, methods=["POST"]),
        Route("/internal_server_error", handler_internal_server_error),
        Route("/unauthorized", handler_unauthorized),
    ]
    return Starlette(
        debug=True,
        routes=routes,
        middleware=[Middleware(StatsDMiddleware, client=statsd)],
        on_startup=[statsd.connect],
        on_shutdown=[statsd.close],
    )
def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(
    test_client_factory,
):
    """An echoed origin for a cookie request must not stick to later requests.

    With a wildcard origin list, a plain request is answered with ``*``. A
    request carrying a cookie gets its own origin echoed back. A following
    plain request must go back to ``*``. No response carries an
    allow-credentials header.
    """
    def homepage(request):
        return PlainTextResponse("Homepage", status_code=200)

    cors = Middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_headers=["*"],
        allow_methods=["*"],
    )
    app = Starlette(routes=[Route("/", endpoint=homepage)], middleware=[cors])
    client = test_client_factory(app)

    origin = "https://someplace.org"
    # (extra request headers, expected Access-Control-Allow-Origin value)
    cases = [
        ({}, "*"),
        ({"Cookie": "foo=bar"}, origin),
        ({}, "*"),
    ]
    for extra_headers, expected_origin in cases:
        response = client.get("/", headers={**extra_headers, "Origin": origin})
        assert response.headers["access-control-allow-origin"] == expected_origin
        assert "access-control-allow-credentials" not in response.headers
def test_cors_allow_origin_regex_fullmatch(test_client_factory):
    """``allow_origin_regex`` must match the whole Origin, not just a prefix."""
    def homepage(request):
        return PlainTextResponse("Homepage", status_code=200)

    cors = Middleware(
        CORSMiddleware,
        allow_headers=["X-Example", "Content-Type"],
        allow_origin_regex=r"https://.*\.example.org",
    )
    app = Starlette(routes=[Route("/", endpoint=homepage)], middleware=[cors])
    client = test_client_factory(app)

    def fetch(origin):
        response = client.get("/", headers={"Origin": origin})
        assert response.status_code == 200
        assert response.text == "Homepage"
        return response

    # Allowed origin: echoed back, without an allow-credentials header.
    allowed = "https://subdomain.example.org"
    response = fetch(allowed)
    assert response.headers["access-control-allow-origin"] == allowed
    assert "access-control-allow-credentials" not in response.headers

    # Disallowed origin: the regex only matches a prefix of it, so no CORS header.
    response = fetch("https://subdomain.example.org.hacker.com")
    assert "access-control-allow-origin" not in response.headers
def test_session_expires(test_client_factory):
    """A session cookie issued with a negative ``max_age`` yields an empty session."""
    routes = [
        Route("/view_session", endpoint=view_session),
        Route("/update_session", endpoint=update_session, methods=["POST"]),
    ]
    session_layer = Middleware(SessionMiddleware, secret_key="example", max_age=-1)
    app = Starlette(routes=routes, middleware=[session_layer])
    client = test_client_factory(app)

    response = client.post("/update_session", json={"some": "data"})
    assert response.json() == {"session": {"some": "data"}}

    # The HTTP client drops already-expired cookies from response.cookies, so
    # extract the session value from the raw Set-Cookie header and send it
    # back explicitly.
    set_cookie = response.headers["set-cookie"]
    match = re.search(r"session=([^;]*);", set_cookie)
    assert match is not None
    stale_session = match.group(1)

    response = client.get("/view_session", cookies={"session": stale_session})
    assert response.json() == {"session": {}}
def init_with_instana(wrapped, instance, args, kwargs):
    """Inject Instana instrumentation into a constructor call's kwargs.

    The ``(wrapped, instance, args, kwargs)`` signature follows the wrapt
    wrapper convention. It presumably wraps ``Starlette.__init__`` — confirm
    at the registration site.

    * ``middleware``: ``None``/missing becomes a one-element list with
      ``InstanaASGIMiddleware``. A list or tuple is replaced by a new list
      with the Instana middleware appended last (innermost). Other values
      are left untouched.
    * ``exception_handlers``: ``None``/missing becomes a new dict. A dict is
      replaced by a copy. In both cases ``HTTPException`` maps to
      ``instana_exception_handler``, replacing any handler the caller gave
      for it. Non-dict values are left untouched.

    Caller-supplied containers are copied rather than mutated, so reusing
    them across several constructions does not accumulate duplicate
    entries. Only keyword arguments are inspected; positional values pass
    through unchanged.

    Returns:
        Whatever ``wrapped(*args, **kwargs)`` returns.
    """
    middleware = kwargs.get('middleware')
    if middleware is None:
        kwargs['middleware'] = [Middleware(InstanaASGIMiddleware)]
    elif isinstance(middleware, (list, tuple)):
        kwargs['middleware'] = [*middleware, Middleware(InstanaASGIMiddleware)]

    exception_handlers = kwargs.get('exception_handlers')
    if exception_handlers is None:
        kwargs['exception_handlers'] = {HTTPException: instana_exception_handler}
    elif isinstance(exception_handlers, dict):
        kwargs['exception_handlers'] = {
            **exception_handlers,
            HTTPException: instana_exception_handler,
        }
    return wrapped(*args, **kwargs)
def client() -> starlette.testclient.TestClient:
    """Return a TestClient for an app whose endpoints require ``my_scope``.

    Authentication uses ``layabauth``'s ``OAuth2IdTokenBackend`` pointed at a
    test JWKS URI. The user is a ``SimpleUser`` named after the token's
    ``upn`` claim, and every token is granted ``my_scope``. GET, POST, PUT
    and DELETE on ``/requires_authentication`` each echo the user's display
    name.
    """
    backend = layabauth.starlette.OAuth2IdTokenBackend(
        jwks_uri="https://test_identity_provider",
        create_user=lambda token, token_body: SimpleUser(token_body["upn"]),
        scopes=lambda token, token_body: ["my_scope"],
    )
    auth_layer = Middleware(AuthenticationMiddleware, backend=backend)
    application = starlette.applications.Starlette(middleware=[auth_layer])
    path = "/requires_authentication"

    @application.route(path, methods=["GET"])
    @requires("my_scope")
    async def get_requires_authentication(request):
        return PlainTextResponse(request.user.display_name)

    @application.route(path, methods=["POST"])
    @requires("my_scope")
    def post_requires_authentication(request):
        return PlainTextResponse(request.user.display_name)

    @application.route(path, methods=["PUT"])
    @requires("my_scope")
    def put_requires_authentication(request):
        return PlainTextResponse(request.user.display_name)

    @application.route(path, methods=["DELETE"])
    @requires("my_scope")
    def delete_requires_authentication(request):
        return PlainTextResponse(request.user.display_name)

    return starlette.testclient.TestClient(application)
def test_https_redirect_middleware(test_client_factory):
    """HTTPS passes through; HTTP gets a 307 redirect to the HTTPS URL.

    Ports 80 and 443 are dropped from the redirect target; any other port
    is preserved.
    """
    def homepage(request):
        return PlainTextResponse("OK", status_code=200)

    app = Starlette(
        routes=[Route("/", endpoint=homepage)],
        middleware=[Middleware(HTTPSRedirectMiddleware)],
    )

    secure_client = test_client_factory(app, base_url="https://testserver")
    assert secure_client.get("/").status_code == 200

    # (extra test_client_factory kwargs, expected redirect Location)
    redirect_cases = [
        ({}, "https://testserver/"),
        ({"base_url": "http://testserver:80"}, "https://testserver/"),
        ({"base_url": "http://testserver:443"}, "https://testserver/"),
        ({"base_url": "http://testserver:123"}, "https://testserver:123/"),
    ]
    for factory_kwargs, expected_location in redirect_cases:
        client = test_client_factory(app, **factory_kwargs)
        response = client.get("/", allow_redirects=False)
        assert response.status_code == 307
        assert response.headers["location"] == expected_location
def __init__(self, accounts_client):
    """Set up WebFinger handling backed by ``accounts_client``.

    Builds an internal Starlette ASGI app exposing ``GET /webfinger``
    (handled by ``self.webfinger``) behind a CORS middleware that allows any
    origin.
    """
    self._accounts_client = accounts_client
    # Presumably keyed by resource URI scheme ('acct:') — confirm in webfinger().
    self._supported_uris = {'acct': self.acct_handler}
    # Response media type -> representation class.
    self._supported_media_types = {'application/jrd+json': JRD}
    self._asgi_app = Starlette(
        routes=[Route('/webfinger', self.webfinger, methods=['GET'])],
        middleware=[Middleware(CORSMiddleware, allow_origins=['*'])],
    )
def test_session(test_client_factory):
    """Round-trip data through SessionMiddleware: empty, set, read, clear."""
    routes = [
        Route("/view_session", endpoint=view_session),
        Route("/update_session", endpoint=update_session, methods=["POST"]),
        Route("/clear_session", endpoint=clear_session, methods=["POST"]),
    ]
    app = Starlette(
        routes=routes,
        middleware=[Middleware(SessionMiddleware, secret_key="example")],
    )
    client = test_client_factory(app)

    empty = {"session": {}}
    populated = {"session": {"some": "data"}}

    assert client.get("/view_session").json() == empty

    response = client.post("/update_session", json={"some": "data"})
    assert response.json() == populated

    # Without an explicit max_age the cookie lifetime is two weeks.
    max_age = re.search(r"; Max-Age=([0-9]+);", response.headers["set-cookie"])
    assert max_age is not None
    assert int(max_age.group(1)) == 14 * 24 * 3600

    assert client.get("/view_session").json() == populated
    assert client.post("/clear_session").json() == empty
    assert client.get("/view_session").json() == empty
def test_middleware():
    """Check that ``http_middlewares`` (here CORS) are applied by Ray Serve.

    Cases adapted from Starlette's CORS middleware tests:
    https://github.com/encode/starlette/blob/master/tests/middleware/test_cors.py

    Ray is shut down in a ``finally`` block so a failing assertion does not
    leave a running cluster behind for later tests.
    """
    from starlette.middleware import Middleware
    from starlette.middleware.cors import CORSMiddleware

    port = new_port()
    try:
        serve.start(http_port=port, http_middlewares=[
            Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
        ])
        ray.get(block_until_http_ready.remote(f"http://127.0.0.1:{port}/-/routes"))

        headers = {
            "Origin": "https://example.org",
            "Access-Control-Request-Method": "GET",
        }
        root = f"http://localhost:{port}"

        # Preflight-style OPTIONS request.
        resp = requests.options(root, headers=headers)
        assert resp.headers["access-control-allow-origin"] == "*"

        # Simple GET request.
        resp = requests.get(f"{root}/-/routes", headers=headers)
        assert resp.headers["access-control-allow-origin"] == "*"
    finally:
        ray.shutdown()
def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed(
    test_client_factory,
):
    """With ``allow_origins=["*"]`` and credentials, a preflight echoes the Origin.

    The response also allows credentials and sets ``Vary: Origin``.
    """
    def homepage(request):
        return  # pragma: no cover

    cors = Middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["POST"],
        allow_credentials=True,
    )
    app = Starlette(routes=[Route("/", endpoint=homepage)], middleware=[cors])
    client = test_client_factory(app)

    origin = "https://example.org"
    preflight_headers = {
        "Origin": origin,
        "Access-Control-Request-Method": "POST",
    }
    response = client.options("/", headers=preflight_headers)

    assert response.status_code == 200
    assert response.headers["access-control-allow-origin"] == origin
    assert response.headers["access-control-allow-credentials"] == "true"
    assert response.headers["vary"] == "Origin"
def __init__(self, port=None, tag=None):
    """Connect to the Ray cluster and start, or attach to, Ray Serve.

    Args:
        port: HTTP port passed to Serve's ``http_options``. Proxies listen on
            ``0.0.0.0`` on every node with a CORS middleware allowing any
            origin and method.
        tag: optional logger-name prefix. When given, the logger is
            ``"<tag>.<RAY.TAG>"``; otherwise it is this module's logger.
    """
    logger_name = __name__ if tag is None else "%s.%s" % (tag, RAY.TAG)
    self.log = logging.getLogger(logger_name)

    ray.init(address="auto")
    nodes_info = ray.nodes()
    node_count = len(nodes_info)

    http_options = {
        "location": "EveryNode",
        "host": "0.0.0.0",
        "port": port,
        "middlewares": [
            Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
        ],
    }
    try:
        self.client = serve.start(http_options=http_options, detached=True)
    except RayServeException:
        # Presumably Serve is already running on the cluster; attach to it.
        self.client = serve.connect()
        self.log.info(
            "Connected existing Ray serve, node number: {} \n Nodes Info: {}"
            .format(node_count, nodes_info))
    else:
        self.log.info(
            "Ray serve initialized, node number: {} \n Nodes Info: {}".format(
                node_count, nodes_info))
def setup_api(title: str, description: str, version: str, cors_origins: List[str], routers: List):
    """Create a FastAPI app with request-context middleware, optional CORS and routers.

    Args:
        title, description, version: forwarded to ``FastAPI``.
        cors_origins: allowed CORS origins. When non-empty, a CORS middleware
            allowing credentials and all methods/headers is added. When
            empty, no CORS middleware is installed.
        routers: objects exposing a ``.router`` attribute. Each is included
            in the given order.

    Returns:
        The configured ``FastAPI`` application.
    """
    middleware = [
        # Attach request-id and correlation-id to the request context.
        Middleware(RawContextMiddleware,
                   plugins=(plugins.RequestIdPlugin(), plugins.CorrelationIdPlugin()))
    ]
    app = FastAPI(title=title, description=description, version=version, middleware=middleware)

    if cors_origins:
        # Let the logging framework do the interpolation (lazy %-style args).
        _LOGGER.debug('Adding CORS middleware for origins %s', ", ".join(cors_origins))
        app.add_middleware(
            CORSMiddleware,
            allow_origins=cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # FastAPIInstrumentor.instrument_app(app)
    # PrometheusInstrumentator().instrument(app).expose(app)
    for router in routers:
        app.include_router(router.router)
    return app