async def config_routes(self):
    """Build the Starlette route table for the hosted thing(s).

    With ``MultipleThings`` every thing gets the href prefix
    ``{base_path}/{idx}`` and all routes are keyed by ``thing_id``. Otherwise
    the single thing is served directly under ``base_path``.

    When ``self.additional_routes`` is a list it is prepended to the thing
    routes. When ``self.base_path`` is non-empty, the thing routes are
    mounted under it and preceded by the base-path route(s). Otherwise the
    thing routes are returned unmounted and the base-path routes are omitted.

    :return: list of Starlette routes.
    """
    if isinstance(self.things, MultipleThings):
        # NOTE(review): get_things() presumably yields (id, thing) pairs,
        # given the unpacking below -- TODO confirm.
        for idx, thing in await self.things.get_things():
            await thing.set_href_prefix(f"{self.base_path}/{idx}")
        base_route = [
            Route(f"{self.base_path}", ThingsHandler),
        ]
        # Handlers mirror the single-thing routes below: a specific action
        # is served by ActionIDHandler, and the event list by EventsHandler
        # (EventHandler expects an ``event_name`` path parameter).
        routes = [
            Route("/{thing_id:str}", ThingHandler),
            WebSocketRoute("/{thing_id:str}", WsThingHandler),
            Route("/{thing_id:str}/properties", PropertiesHandler),
            Route("/{thing_id:str}/properties/{property_name:str}", PropertyHandler),
            Route("/{thing_id:str}/actions", ActionsHandler),
            Route("/{thing_id:str}/actions/{action_name:str}", ActionHandler),
            Route("/{thing_id:str}/actions/{action_name:str}/{action_id:str}", ActionIDHandler),
            Route("/{thing_id:str}/events", EventsHandler),
            Route("/{thing_id:str}/events/{event_name:str}", EventHandler),
        ]
    else:
        thing = await self.things.get_thing()
        await thing.set_href_prefix(self.base_path)
        base_route = [
            Route(f"{self.base_path}", ThingHandler),
            WebSocketRoute(f"{self.base_path}", WsThingHandler),
        ]
        routes = [
            Route("/properties", PropertiesHandler),
            Route("/properties/{property_name:str}", PropertyHandler),
            Route("/actions", ActionsHandler),
            Route("/actions/{action_name:str}", ActionHandler),
            Route("/actions/{action_name:str}/{action_id:str}", ActionIDHandler),
            Route("/events", EventsHandler),
            Route("/events/{event_name:str}", EventHandler),
        ]

    # TODO: should additional_routes also have the base-path prefix?
    if isinstance(self.additional_routes, list):
        routes = self.additional_routes + routes

    if self.base_path:
        routes = base_route + [
            Mount(f"{self.base_path}", routes=routes),
        ]
    return routes
def routes(self):
    """Return the application's route table.

    Serves ``self.index`` at ``/`` and ``self.status`` at ``/status``.
    Files from the ``self.static`` directory are mounted under ``/static``,
    and ``self.ws_endpoint`` handles the ``/ws`` websocket.
    """
    index_route = Route('/', self.index)
    status_route = Route('/status', endpoint=self.status)
    static_mount = Mount('/static', app=StaticFiles(directory=self.static))
    socket_route = WebSocketRoute('/ws', self.ws_endpoint)
    return [index_route, status_route, static_mount, socket_route]
def build_routes(self):
    """Assemble the Dagit route table, in matching order.

    The order is:

    1. the info, GraphQL HTTP and GraphQL websocket routes;
    2. whatever ``self.build_static_routes()`` returns;
    3. the download endpoints;
    4. the catch-all index routes.

    NOTE(review): Starlette's ``path`` convertor also matches an empty
    string, so ``/{path:path}`` already matches ``/``. The explicit ``/``
    route after it is therefore never reached for HTTP requests.
    """
    api_routes = [
        Route("/dagit_info", self.dagit_info_endpoint),
        Route(
            "/graphql",
            self.graphql_http_endpoint,
            name="graphql-http",
            methods=["GET", "POST"],
        ),
        WebSocketRoute("/graphql", self.graphql_ws_endpoint, name="graphql-ws"),
    ]
    static_routes = self.build_static_routes()
    download_and_index_routes = [
        # download file endpoints
        Route(
            "/download/{run_id:str}/{step_key:str}/{file_type:str}",
            self.download_compute_logs_endpoint,
        ),
        Route("/dagit/notebook", self.download_notebook),
        Route("/download_debug/{run_id:str}", self.download_debug_file_endpoint),
        Route("/{path:path}", self.index_html_endpoint),
        Route("/", self.index_html_endpoint),
    ]
    return api_routes + static_routes + download_and_index_routes
def routes(ws: WebsocketManager) -> List[Any]:
    """Return a single websocket route that serves every room.

    Each connection to ``/{room_id}`` is wrapped in a
    ``StarletteWebsocketClient`` and passed to ``ws.connection_handler``.
    """
    # A lambda (rather than a named function) is kept on purpose: Starlette
    # derives the route name from the endpoint when no name is given.
    room_route = WebSocketRoute(
        '/{room_id}',
        lambda websocket: ws.connection_handler(
            StarletteWebsocketClient(websocket)),
    )
    return [room_route]
def create_app():
    """Create the Starlette application.

    Serves ``index`` on ``GET /`` and the ``websocket_route`` websocket on
    ``/game``, with ``SessionMiddleware`` installed.
    """
    app_routes = [
        Route('/', endpoint=index, methods=['GET']),
        WebSocketRoute('/game', endpoint=websocket_route),
    ]
    # NOTE(review): hard-coded secret key that looks like a test value; a
    # production deployment should take it from configuration.
    app_middleware = [
        Middleware(SessionMiddleware, secret_key='testingsecretkey'),
    ]
    return Starlette(routes=app_routes, middleware=app_middleware)
def __init__(self, name, serial=None, debug=False):
    """Set up the device web app and schedule the screenshot task at startup.

    Routes:

    - ``/`` serves the index page;
    - ``/ws/screenshots`` and ``/ws/control`` are websocket endpoints;
    - ``/static`` serves files from ``static_dir``.

    :param name: stored as ``self.name``.
    :param serial: stored as ``self.serial``; defaults to ``None``.
    :param debug: forwarded to ``Starlette.__init__``.
    """
    self.name = name
    self.serial = serial
    self.templates = Jinja2Templates(directory=templates_dir)
    self.screenshot_queues = set()
    # Handle of the background screenshot task; assigned on startup.
    self._screenshot_task = None

    routes = [
        Route("/", self.index_route),
        WebSocketRoute("/ws/screenshots", self.screenshot_endpoint),
        WebSocketRoute("/ws/control", self.control_endpoint),
        Mount("/static", StaticFiles(directory=static_dir), name="static"),
    ]
    Starlette.__init__(self, debug=debug, routes=routes)

    async def startup_handler():
        # Keep a strong reference: the event loop only keeps weak references
        # to tasks, so an unreferenced task may be garbage-collected mid-run.
        self._screenshot_task = asyncio.ensure_future(self.screenshot_task())

    self.add_event_handler("startup", startup_handler)
def __init__(
    self,
    *,
    engine: Engine = None,
    sdl: str = None,
    graphiql: typing.Union[None, bool, GraphiQL] = True,
    path: str = "/",
    subscriptions: typing.Union[None, bool, Subscriptions] = None,
    context: dict = None,
    schema_name: str = "default",
):
    """Build the GraphQL ASGI app: GraphQL route, optional GraphiQL and subscriptions.

    :param engine: ready-made engine; if omitted one is built from ``sdl``.
    :param sdl: schema definition used when ``engine`` is not given.
    :param graphiql: ``True`` for a default ``GraphiQL()``, a falsy value to
        disable it, or a ``GraphiQL`` instance.
    :param path: path of the GraphQL HTTP route.
    :param subscriptions: ``True`` for ``Subscriptions(path="/subscriptions")``,
        a falsy value to disable them, or a ``Subscriptions`` instance.
    :param context: passed to ``GraphQLConfig``; defaults to an empty dict.
    :param schema_name: used when building the engine from ``sdl``.
    :raises AssertionError: if neither ``engine`` nor ``sdl`` is provided.
    """
    if engine is None:
        assert sdl, "`sdl` expected if `engine` not given"
        engine = Engine(sdl=sdl, schema_name=schema_name)

    assert engine, "`engine` expected if `sdl` not given"

    self.engine = engine

    if context is None:
        context = {}

    # Normalise the bool shortcuts: True -> default config, falsy -> disabled.
    if graphiql is True:
        graphiql = GraphiQL()
    elif not graphiql:
        graphiql = None

    if subscriptions is True:
        subscriptions = Subscriptions(path="/subscriptions")
    elif not subscriptions:
        # Without this, subscriptions=False would pass the `is not None`
        # check below and fail on `False.path`.
        subscriptions = None

    routes = []

    if graphiql and graphiql.path is not None:
        routes.append(Route(path=graphiql.path, endpoint=GraphiQLEndpoint))

    graphql_route = Route(path=path, endpoint=GraphQLEndpoint)
    routes.append(graphql_route)

    if subscriptions is not None:
        subscription_route = WebSocketRoute(
            path=subscriptions.path, endpoint=SubscriptionEndpoint)
        routes.append(subscription_route)

    config = GraphQLConfig(
        engine=self.engine,
        context=context,
        graphiql=graphiql,
        path=graphql_route.path,
        subscriptions=subscriptions,
    )

    router = Router(routes=routes)
    self.app = GraphQLMiddleware(router, config=config)

    self.lifespan = Lifespan(on_startup=self.startup)

    self._started_up = False
def build_tasks_api(manager):
    """Build the Starlette sub-application for the task-manager API.

    ``GET /`` is always served by ``tasks``. The ``/{task}`` websocket
    (``TaskEndpoint``) is added only when the manager's task worker is
    enabled. The manager's broadcast connects on startup and disconnects on
    shutdown.
    """
    http_routes = [Route("/", endpoint=tasks, methods=["GET"])]
    socket_routes = (
        [WebSocketRoute("/{task}", TaskEndpoint, name="task_manager_api")]
        if manager._task_worker_enabled
        else []
    )
    return Starlette(
        routes=http_routes + socket_routes,
        on_startup=[manager.broadcast.connect],
        on_shutdown=[manager.broadcast.disconnect],
    )
def create_app(
    process_context: WorkspaceProcessContext,
    debug: bool,
    app_path_prefix: str,
):
    """Create the Dagit Starlette application.

    :param process_context: bound into the GraphQL HTTP/websocket handlers
        and the debug-file download endpoint.
    :param debug: forwarded to ``Starlette``.
    :param app_path_prefix: bound into the GraphQL HTTP handler and the
        index endpoint.
    """
    graphql_schema = create_schema()
    base_dir = path.dirname(__file__)
    index = partial(index_endpoint, base_dir, app_path_prefix)

    graphql_routes = [
        Route("/dagit_info", dagit_info_endpoint),
        Route(
            "/graphql",
            partial(graphql_http_endpoint, graphql_schema, process_context, app_path_prefix),
            name="graphql-http",
            methods=["GET", "POST"],
        ),
        WebSocketRoute(
            "/graphql",
            partial(graphql_ws_endpoint, graphql_schema, process_context),
            name="graphql-ws",
        ),
    ]
    static_routes = [
        # Build output served under /static/ and /vendor/.
        Mount(
            "/static",
            StaticFiles(directory=path.join(base_dir, "./webapp/build/static")),
            name="static",
        ),
        Mount(
            "/vendor",
            StaticFiles(directory=path.join(base_dir, "./webapp/build/vendor")),
            name="vendor",
        ),
        # Individual static files addressed directly at /.
        *create_root_static_endpoints(base_dir),
    ]
    fallback_routes = [
        Route(
            "/download_debug/{run_id:str}",
            partial(download_debug_file_endpoint, process_context),
        ),
        Route("/{path:path}", index),
        Route("/", index),
    ]
    return Starlette(
        debug=debug,
        routes=graphql_routes + static_routes + fallback_routes,
    )
def __init__(
    self,
    schema: GraphQLSchema = None,
    *,
    type_defs: str = None,
    schema_file: str = None,
    federation: bool = False,
    playground: bool = True,
    debug: bool = False,
    routes: typing.List[BaseRoute] = None,
    path: str = '/',
    subscription_path: str = '/',
    subscription_authenticate: typing.Awaitable = None,
    error_formater: ERROR_FORMATER = None,
    graphql_middleware: typing.Union[tuple, list, typing.Dict[str, list]] = None,
    graphql_middleware_exclude: typing.List[str] = None,
    context_builder: typing.Callable = None,
    **kwargs,
):
    """Create the app with a GraphQL HTTP route and a subscription websocket route.

    :param schema: ready-made schema. If omitted, one is built from
        ``type_defs`` or, failing that, ``schema_file``.
    :param routes: extra routes placed before the GraphQL routes. The given
        list is copied, never modified.
    :param path: path of the GraphQL HTTP route (``ASGIApp``).
    :param subscription_path: path of the subscription websocket route.
    :param kwargs: forwarded to the parent ``__init__``.
    :raises Exception: if none of ``schema``, ``type_defs`` or
        ``schema_file`` is given.
    """
    # Copy so that the extend() below does not mutate the caller's list.
    routes = list(routes) if routes else []
    if schema:
        self.schema = schema
    elif type_defs:
        self.schema = make_schema(type_defs, federation=federation)
    elif schema_file:
        self.schema = make_schema_from_file(schema_file, federation=federation)
    else:
        raise Exception('Must provide type def string or file.')

    routes.extend([
        Route(
            path,
            ASGIApp(
                self.schema,
                debug=debug,
                playground=playground,
                error_formater=error_formater,
                graphql_middleware=graphql_middleware,
                graphql_middleware_exclude=graphql_middleware_exclude,
                context_builder=context_builder,
            ),
        ),
        WebSocketRoute(
            subscription_path,
            Subscription(self.schema, authenticate=subscription_authenticate),
        ),
    ])

    super().__init__(debug=debug, routes=routes, **kwargs)
def test_app_add_websocket_route(test_client_factory):
    """A websocket route passed to Starlette(routes=...) is reachable."""

    async def websocket_endpoint(websocket):
        await websocket.accept()
        await websocket.send_text("Hello, world!")
        await websocket.close()

    ws_route = WebSocketRoute("/ws", endpoint=websocket_endpoint)
    client = test_client_factory(Starlette(routes=[ws_route]))

    with client.websocket_connect("/ws") as connection:
        assert connection.receive_text() == "Hello, world!"
def test_routes():
    """The module-level app exposes exactly the expected routes, in order."""
    users_router = Router(routes=[
        Route("/", endpoint=all_users_page),
        Route("/{username}", endpoint=user_page),
    ])
    expected_routes = [
        Route("/func", endpoint=func_homepage, methods=["GET"]),
        Route("/async", endpoint=async_homepage, methods=["GET"]),
        Route("/class", endpoint=Homepage),
        Mount("/users", app=users_router),
        Route("/500", endpoint=runtime_error, methods=["GET"]),
        WebSocketRoute("/ws", endpoint=websocket_endpoint),
    ]
    assert app.routes == expected_routes
def build_routes(self):
    """Assemble the Dagit route table, optionally nested under a path prefix.

    Without ``self._app_path_prefix`` the flat route list is returned. With
    a prefix, every route is mounted under it, and ``/`` redirects to the
    prefix.
    """
    api_routes = [
        Route("/dagit_info", self.dagit_info_endpoint),
        Route(
            "/graphql",
            self.graphql_http_endpoint,
            name="graphql-http",
            methods=["GET", "POST"],
        ),
        WebSocketRoute("/graphql", self.graphql_ws_endpoint, name="graphql-ws"),
    ]
    static_routes = self.build_static_routes()
    download_and_index_routes = [
        # download file endpoints
        Route(
            "/download/{run_id:str}/{step_key:str}/{file_type:str}",
            self.download_compute_logs_endpoint,
        ),
        Route("/dagit/notebook", self.download_notebook),
        Route("/download_debug/{run_id:str}", self.download_debug_file_endpoint),
        Route("/{path:path}", self.index_html_endpoint),
        Route("/", self.index_html_endpoint),
    ]
    routes = api_routes + static_routes + download_and_index_routes

    if not self._app_path_prefix:
        return routes

    def _redirect(_):
        return RedirectResponse(url=self._app_path_prefix)

    return [Mount(self._app_path_prefix, routes=routes), Route("/", _redirect)]
def ws_magic_route(
    self,
    path: str,
    endpoint: Callable,
    *,
    name: str = None,
) -> WebSocketRoute:
    """Build a starlette ``WebSocketRoute`` whose endpoint is bound to the container.

    The endpoint is wrapped by ``self.wrapped_endpoint_factory`` using the
    container's ``magic_partial``, so its dependencies can be auto-injected.

    :param path: URL path of the route, passed to ``WebSocketRoute`` as is.
    :param endpoint: the callable to wrap.
    :param name: optional route name, passed to ``WebSocketRoute`` as is.
    :return: the new ``WebSocketRoute``.
    """
    return WebSocketRoute(
        path,
        self.wrapped_endpoint_factory(endpoint, self._container.magic_partial),
        name=name,
    )
def on_api_initialized_v2(self, api):
    """Mount the pipeline import-tracker sub-app on ``api`` at ``/import-tracker``.

    The sub-app exposes:

    - ``/pipelines``: a JSON list of the ``self.pipelines`` keys;
    - ``/pipeline/{pipeline}``: a websocket served by ``ImporterEndpoint``.

    This method also creates ``self.broadcast``. ``self.connect`` runs on
    startup, and the broadcast is disconnected on shutdown.
    """
    # This breaks silently if we can't connect
    self.broadcast = Broadcast("redis://broker:6379")

    def import_pipelines(req):
        return JSONResponse(list(self.pipelines.keys()))

    log.debug("Setting up pipelines API")
    tracker_routes = [
        Route("/pipelines", import_pipelines),
        WebSocketRoute("/pipeline/{pipeline}", ImporterEndpoint, name="import_tracker_api"),
    ]
    tracker_app = Starlette(
        routes=tracker_routes,
        on_startup=[self.connect],
        on_shutdown=[self.broadcast.disconnect],
    )
    api.mount("/import-tracker", tracker_app)
def run_server(host, port, no_open, shutdown_timeout, home_url):
    """Start the PIO Home server with uvicorn; blocks until uvicorn returns.

    :param host: passed to ``uvicorn.run``.
    :param port: passed to ``uvicorn.run``.
    :param no_open: when true, ``home_url`` is not launched on startup.
    :param shutdown_timeout: passed to ``WebSocketJSONRPCServerFactory``.
    :param home_url: URL of the UI; its path component prefixes every route.
    :raises PlatformioException: if the PIO Home contrib package directory
        does not exist.
    """
    contrib_dir = get_core_package_dir("contrib-piohome")
    if not os.path.isdir(contrib_dir):
        raise PlatformioException("Invalid path to PIO Home Contrib")

    ws_rpc_factory = WebSocketJSONRPCServerFactory(shutdown_timeout)
    # Instantiate and register each RPC handler, in order, under its namespace.
    for handler_cls, namespace in (
        (AccountRPC, "account"),
        (AppRPC, "app"),
        (IDERPC, "ide"),
        (MiscRPC, "misc"),
        (OSRPC, "os"),
        (PIOCoreRPC, "core"),
        (ProjectRPC, "project"),
    ):
        ws_rpc_factory.addHandler(handler_cls(), namespace=namespace)

    # Route paths are built by appending to this prefix, so it must end with
    # "/". Without that, an empty path (e.g. "http://host:port") would give
    # the invalid route path "wsrpc".
    path = urlparse(home_url).path
    if not path.endswith("/"):
        path += "/"
    routes = [
        WebSocketRoute(path + "wsrpc", ws_rpc_factory, name="wsrpc"),
        Route(path + "__shutdown__", shutdown_server, methods=["POST"]),
        Mount(path, StaticFiles(directory=contrib_dir, html=True), name="static"),
    ]
    if path != "/":
        # The UI lives under a sub-path; the bare root gets protected_page.
        routes.append(Route("/", protected_page))

    uvicorn.run(
        Starlette(
            middleware=[Middleware(ShutdownMiddleware)],
            routes=routes,
            on_startup=[
                lambda: click.echo(
                    "PIO Home has been started. Press Ctrl+C to shutdown."),
                lambda: None if no_open else click.launch(home_url),
            ],
        ),
        host=host,
        port=port,
        log_level="warning",
    )
def reload_endpoint(watch_path: str):
    """Return devtools routes: a live-reload websocket and an up-check route.

    The websocket sends the text ``'reload'`` to the connected client
    whenever ``awatch`` reports changes under ``watch_path``.
    ``/.devtools/up/`` is served by ``devtools_up`` on GET.
    """
    async def watch_reload(prompt_reload):
        # Call prompt_reload once per change notification yielded by awatch.
        async for _ in awatch(watch_path, watcher_cls=FoxgloveWatcher):
            await prompt_reload()

    class ReloadWs(WebSocketEndpoint):
        encoding = 'text'

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Starlette instantiates the endpoint per connection, so each
            # connection gets its own watcher task; it is cancelled in
            # on_disconnect.
            self._watch_task = asyncio.create_task(
                watch_reload(self.prompt_reload))
            # Set in on_connect; until then prompt_reload is a no-op.
            self.ws = None

        async def prompt_reload(self):
            """Send 'reload' to the client, if one has connected."""
            if self.ws:
                logger.debug('prompting reload')
                await self.ws.send_text('reload')

        async def on_connect(self, websocket):
            """Accept the connection and remember it for prompt_reload."""
            logger.debug('reload websocket connecting')
            await websocket.accept()
            self.ws = websocket

        async def on_disconnect(self, websocket, close_code):
            """Stop this connection's file watcher and wait for it to finish."""
            logger.debug('reload websocket disconnecting')
            self._watch_task.cancel()
            # NOTE(review): if the watcher already ended with another
            # exception (e.g. a failed send), awaiting re-raises it here;
            # only CancelledError is handled.
            try:
                await self._watch_task
            except asyncio.CancelledError:
                logger.debug('file watcher cancelled')

    return [
        WebSocketRoute('/.devtools/reload-ws/', ReloadWs, name='devtools-reload'),
        Route('/.devtools/up/', devtools_up, methods=['GET']),
    ]
def add_ws_endpoint(self, *args, **kwargs):
    """
    Add a websocket endpoint to the server

    Args:
        *args: either a single mock ws endpoint, or parameters forwarded to ws_endpoint to construct one
        **kwargs: forwarded to ws_endpoint to construct an endpoint

    Returns:
        the websocket endpoint added to the server

    Raises:
        RuntimeError: if the endpoint has already been added to a server
    """
    self._raise_from_pending()
    if len(args) == 1 and not kwargs:
        ep, = args
    else:
        ep = ws_endpoint(*args, **kwargs)

    with self._route_lock:
        # Check and claim ownership under the same lock that guards the
        # route list, so two concurrent calls cannot both register `ep`.
        if ep.owner is not None:
            raise RuntimeError('an endpoint cannot be added twice')
        self._app.routes.append(
            WebSocketRoute(ep.rule_string, ep.endpoint, name=ep.__name__))
        ep.owner = self
    return ep
def __init__(
        self,
        aggregation: str,
        host: str = None,
        http_port: int = None,
        exit_round: int = 1,
        participants_count: int = 1,
        ws_size: int = 10 * 1024 * 1024):
    """Configure the aggregation server and its FastAPI app.

    :param aggregation: server name; also the path ``/{aggregation}`` used
        for both the HTTP info route and the ``BroadcastWs`` websocket.
    :param host: bind address. If falsy, taken from the ``AGG_BIND_IP``
        parameter, defaulting to ``get_host_ip()``.
    :param http_port: bind port. If falsy, taken from ``AGG_BIND_PORT``,
        defaulting to 7363.
    :param exit_round: stored as ``self.exit_round``, clamped to at least 1.
    :param participants_count: stored as ``self.participants_count``.
    :param ws_size: forwarded to the parent constructor.
    """
    host = host or Context.get_parameters("AGG_BIND_IP", get_host_ip())
    http_port = http_port or int(Context.get_parameters("AGG_BIND_PORT", 7363))
    super(AggregationServer, self).__init__(
        servername=aggregation,
        host=host,
        http_port=http_port,
        ws_size=ws_size,
    )
    self.aggregation = aggregation
    self.participants_count = participants_count
    self.exit_round = max(int(exit_round), 1)

    route_path = f"/{aggregation}"
    info_route = APIRoute(route_path, self.client_info, response_class=JSONResponse)
    broadcast_route = WebSocketRoute(route_path, BroadcastWs)
    self.app = FastAPI(routes=[info_route, broadcast_route])
    self.app.shutdown = False
return is_basic_mutation(type.of_type) return False def validate_schema(schema): for name, mutation in schema.mutation_type.fields.items(): if name == "_unused": continue if not is_non_null_basic_mutation(mutation.type): raise ValueError(f"{name} Must implement IMutationResponse!") validate_schema(schema) graphql_route = GraphQL(schema, middleware=[auth_middleware]) app = Starlette( routes=[ WebSocketRoute("/ws/shell/{uuid:str}", handle_shell), Route("/api", graphql_route), Route("/api/", graphql_route), WebSocketRoute("/api", GraphQL(schema=schema)), Mount("/", StaticFilesFallback(directory="static", html=True)), ], on_startup=[on_startup], on_shutdown=[on_shutdown], middleware=[ Middleware(RawContextMiddleware), Middleware(RequestCacheMiddleware) ], )
def test_standalone_ws_route_does_not_match():
    """A bare WebSocketRoute used as an app rejects a non-matching path."""
    route = WebSocketRoute("/", ws_helloworld)
    client = TestClient(route)
    with pytest.raises(WebSocketDisconnect), client.websocket_connect("/invalid"):
        pass  # pragma: nocover
def test_standalone_ws_route_matches():
    """A bare WebSocketRoute used as an app serves its endpoint on a matching path."""
    client = TestClient(WebSocketRoute("/", ws_helloworld))
    with client.websocket_connect("/") as connection:
        assert connection.receive_text() == "Hello, world!"
url = request.url_for("http_endpoint") return Response(f"URL: {url}", media_type="text/plain") class WebSocketEndpoint: async def __call__(self, scope, receive, send): websocket = WebSocket(scope=scope, receive=receive, send=send) await websocket.accept() await websocket.send_json( {"URL": str(websocket.url_for("websocket_endpoint"))}) await websocket.close() mixed_protocol_app = Router(routes=[ Route("/", endpoint=http_endpoint), WebSocketRoute( "/", endpoint=WebSocketEndpoint(), name="websocket_endpoint"), ]) def test_protocol_switch(): client = TestClient(mixed_protocol_app) response = client.get("/") assert response.status_code == 200 assert response.text == "URL: http://testserver/" with client.websocket_connect("/") as session: assert session.receive_json() == {"URL": "ws://testserver/"} with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/404"):
async def background_task(message):
    """Stand-in for slow work (sending mail, hashing a password, ...): wait 5s, then print *message*."""
    await sleep(5)
    print(message)


async def test(request):
    """Starlette function endpoint that attaches a background task to its response."""
    return PlainTextResponse(
        "Testing Starlette functional endpoint with a bg task!",
        background=BackgroundTask(background_task, message="testing background task!"),
    )


# FastAPI route table: the Sanic app and the static files are mounted next
# to native Starlette routes.
routes = [
    Mount("/sanic-api", sanic_app),
    WebSocketRoute("/ws", websocket_endpoint),
    Route('/test', endpoint=test),
    Mount("/", starlette_static_files),
]

app = FastAPI(routes=routes)

# Uncomment to enable the trusted host middleware:
# app.add_middleware(
#     TrustedHostMiddleware,
#     allowed_hosts=['example.com', '*.example.com'],
# )
pass # pragma: no cover def schema(request): return schemas.OpenAPIResponse(request=request) subapp = Starlette( routes=[ Route("/subapp-endpoint", endpoint=subapp_endpoint), ] ) app = Starlette( routes=[ WebSocketRoute("/ws", endpoint=ws), Route("/users", endpoint=list_users, methods=["GET", "HEAD"]), Route("/users", endpoint=create_user, methods=["POST"]), Route("/orgs", endpoint=OrganisationsEndpoint), Route("/regular-docstring-and-schema", endpoint=regular_docstring_and_schema), Route("/regular-docstring", endpoint=regular_docstring), Route("/no-docstring", endpoint=no_docstring), Route("/schema", endpoint=schema, methods=["GET"], include_in_schema=False), Mount("/subapp", subapp), ] ) def test_schema_generation(): schema = schemas.get_schema(routes=app.routes) assert schema == {
</body> <script type="text/javascript"> const ws = new WebSocket(`ws://${window.location.host}/ws`); ws.onmessage = (e) => { const data = e.data; console.log(data); }; ws.onclose = (e) => { console.error(e) }; </script> </html> """ return HTMLResponse(html) app = Starlette(debug=True, routes=[Route("/", index), WebSocketRoute("/ws", hello)]) def test_app(): from starlette.testclient import TestClient c = TestClient(app) with c.websocket_connect("/ws") as ws: data = ws.receive_text() print(data) if __name__ == "__main__": test_app()
from starlette.routing import Route, WebSocketRoute


async def homepage(request):
    """Serve the static test page."""
    return FileResponse('./tests/index.html')


class MyElement(CustomElement):
    """Heading showing ``self.test``; after a click it renders a span instead."""

    async def setup(self):
        self.test = "foo"

    async def click(self):
        self.test = "bar"

    async def render(self):
        if self.test != "bar":
            return h1(on_click=self.click)[self.test]
        return span[i["baz"], "foobar", b["barbaz"]]


async def on_accept(websocket: HTMLOverTheAirWebSocket):
    """Render ``MyElement`` into the DOM of an accepted websocket."""
    await websocket.dom.render(MyElement)


routes = [
    Route("/", endpoint=homepage),
    WebSocketRoute("/ws", endpoint=HTMLOverTheAirEndpoint(on_accept=on_accept)),
]
app = Starlette(routes=routes)
class HandledExcAfterResponse:
    """App class (scope passed to the constructor) that sends a 200 "OK" and then raises HTTPException(406)."""

    def __init__(self, scope):
        pass

    async def __call__(self, receive, send):
        await PlainTextResponse("OK", status_code=200)(receive, send)
        raise HTTPException(status_code=406)


router = Router(routes=[
    Route("/runtime_error", endpoint=raise_runtime_error),
    Route("/not_acceptable", endpoint=not_acceptable),
    Route("/not_modified", endpoint=not_modified),
    Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse),
    WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
])

app = ExceptionMiddleware(router)
client = TestClient(app)


def test_server_error():
    """Unhandled errors propagate by default, or become a plain 500 when suppressed."""
    with pytest.raises(RuntimeError):
        client.get("/runtime_error")

    lenient_client = TestClient(app, raise_server_exceptions=False)
    response = lenient_client.get("/runtime_error")
    assert response.status_code == 500
    assert response.text == "Internal Server Error"
def __init__(
    self,
    *,
    engine: Engine = None,
    sdl: str = None,
    graphiql: typing.Union[None, bool, GraphiQL] = True,
    path: str = "/",
    subscriptions: typing.Union[bool, Subscriptions] = None,
    context: dict = None,
    schema_name: str = "default",
) -> None:
    """Build the GraphQL ASGI app: GraphQL route, optional GraphiQL and subscriptions.

    :param engine: ready-made engine; built from ``sdl`` when omitted.
    :param sdl: schema definition used when ``engine`` is not given.
    :param graphiql: ``True`` for a default ``GraphiQL()``, falsy to
        disable it, or a ``GraphiQL`` instance.
    :param path: path of the GraphQL HTTP route.
    :param subscriptions: ``True`` for ``Subscriptions(path="/subscriptions")``,
        falsy to disable them, or a ``Subscriptions`` instance.
    :param context: passed to ``GraphQLConfig``; defaults to an empty dict.
    :param schema_name: used when building the engine from ``sdl``.
    :raises AssertionError: on missing ``engine``/``sdl`` or wrongly-typed options.
    """
    if engine is None:
        assert sdl, "`sdl` expected if `engine` not given"
        engine = Engine(sdl=sdl, schema_name=schema_name)
    assert engine, "`engine` expected if `sdl` not given"
    self.engine = engine

    context = {} if context is None else context

    # True selects the default config; any falsy value disables the feature.
    graphiql = GraphiQL() if graphiql is True else (graphiql or None)
    assert graphiql is None or isinstance(graphiql, GraphiQL)

    subscriptions = (
        Subscriptions(path="/subscriptions")
        if subscriptions is True
        else (subscriptions or None)
    )
    assert subscriptions is None or isinstance(subscriptions, Subscriptions)

    routes: typing.List[BaseRoute] = []
    if graphiql and graphiql.path is not None:
        routes.append(Route(graphiql.path, GraphiQLEndpoint))
    routes.append(Route(path, GraphQLEndpoint))
    if subscriptions is not None:
        routes.append(WebSocketRoute(subscriptions.path, SubscriptionEndpoint))

    self.router = Router(routes=routes, on_startup=[self.startup])
    self.app = GraphQLMiddleware(
        self.router,
        config=GraphQLConfig(
            engine=self.engine,
            context=context,
            graphiql=graphiql,
            path=path,
            subscriptions=subscriptions,
        ),
    )
    self._started_up = False
from starlette.applications import Starlette
from starlette.routing import Route, WebSocketRoute
from starlette.middleware import Middleware
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
from sqlalchemy import update

import config
from views import Server
from models import Song

# The only route: the websocket server at /ws.
ROUTES = [
    WebSocketRoute('/ws', Server)
]

MIDDLEWARES = [
    Middleware(CORSMiddleware, allow_origins=config.CORS_ORIGIN),
    Middleware(GZipMiddleware),
]

if not config.DEBUG:
    # Outside debug mode, also force HTTPS and restrict accepted Host headers.
    MIDDLEWARES.append(Middleware(HTTPSRedirectMiddleware))
    MIDDLEWARES.append(Middleware(TrustedHostMiddleware, allowed_hosts=config.ALLOWED_HOSTS))

# MIDDLEWARES must be passed explicitly; otherwise the list above has no
# effect and none of these middlewares run.
app = Starlette(debug=config.DEBUG, routes=ROUTES, middleware=MIDDLEWARES)
app.state.config = config