예제 #1
0
    async def config_routes(self):
        """Set each thing's href prefix and build the Starlette route table.

        With ``MultipleThings`` every thing gets the href prefix
        ``{base_path}/{index}`` and the routes take a ``thing_id`` path
        parameter; with a single thing the routes hang directly off
        ``base_path``. When ``self.additional_routes`` is a list it is placed
        in front of the generated per-thing routes.

        Returns:
            The route list. When ``self.base_path`` is truthy, the per-thing
            routes are mounted under it and preceded by the base route(s);
            otherwise the unmounted list is returned.
        """
        if isinstance(self.things, MultipleThings):
            # Number the things explicitly; the index is used in each thing's
            # href prefix. (Assumes get_things() returns the things
            # themselves, not (index, thing) pairs — unpacking a thing into
            # two names would fail.)
            things = await self.things.get_things()
            for idx, thing in enumerate(things):
                await thing.set_href_prefix(f"{self.base_path}/{idx}")

            base_route = [
                Route(f"{self.base_path}", ThingsHandler),
            ]
            routes = [
                Route("/{thing_id:str}", ThingHandler),
                WebSocketRoute("/{thing_id:str}", WsThingHandler),
                Route("/{thing_id:str}/properties", PropertiesHandler),
                Route("/{thing_id:str}/properties/{property_name:str}",
                      PropertyHandler),
                Route("/{thing_id:str}/actions", ActionsHandler),
                Route("/{thing_id:str}/actions/{action_name:str}",
                      ActionHandler),
                # A specific action request is served by ActionIDHandler,
                # matching the single-thing routes below.
                Route("/{thing_id:str}/actions/{action_name:str}/{action_id:str}",
                      ActionIDHandler),
                # The event collection is served by EventsHandler; EventHandler
                # serves a single named event (as in the single-thing routes).
                Route("/{thing_id:str}/events", EventsHandler),
                Route("/{thing_id:str}/events/{event_name:str}", EventHandler),
            ]
        else:
            thing = await self.things.get_thing()
            await thing.set_href_prefix(self.base_path)

            base_route = [
                Route(f"{self.base_path}", ThingHandler),
                WebSocketRoute(f"{self.base_path}", WsThingHandler),
            ]
            routes = [
                Route("/properties", PropertiesHandler),
                Route("/properties/{property_name:str}", PropertyHandler),
                Route("/actions", ActionsHandler),
                Route("/actions/{action_name:str}", ActionHandler),
                Route("/actions/{action_name:str}/{action_id:str}",
                      ActionIDHandler),
                Route("/events", EventsHandler),
                Route("/events/{event_name:str}", EventHandler),
            ]

        # TODO: should additional_routes also have prefix?
        if isinstance(self.additional_routes, list):
            routes = self.additional_routes + routes

        if self.base_path:
            routes = base_route + [
                Mount(f"{self.base_path}", routes=routes),
            ]

        return routes
예제 #2
0
 def routes(self):
     """Return the route table.

     The table contains the index page, the status endpoint, the static file
     mount and the ``/ws`` websocket.
     """
     page_routes = [
         Route('/', self.index),
         Route('/status', endpoint=self.status),
     ]
     static_mount = Mount('/static', app=StaticFiles(directory=self.static))
     socket_route = WebSocketRoute('/ws', self.ws_endpoint)
     return [*page_routes, static_mount, socket_route]
예제 #3
0
 def build_routes(self):
     """Assemble the full route list, in matching order.

     The order is:

     1. API and websocket routes.
     2. Whatever ``build_static_routes`` returns.
     3. The download endpoints.
     4. The catch-all index routes.

     With Starlette's ``path`` convertor (regex ``.*``),
     ``"/{path:path}"`` also matches ``"/"``.
     """
     api_routes = [
         Route("/dagit_info", self.dagit_info_endpoint),
         Route(
             "/graphql",
             self.graphql_http_endpoint,
             name="graphql-http",
             methods=["GET", "POST"],
         ),
         WebSocketRoute(
             "/graphql",
             self.graphql_ws_endpoint,
             name="graphql-ws",
         ),
     ]
     static_routes = self.build_static_routes()
     fallback_routes = [
         # download file endpoints
         Route(
             "/download/{run_id:str}/{step_key:str}/{file_type:str}",
             self.download_compute_logs_endpoint,
         ),
         Route("/dagit/notebook", self.download_notebook),
         Route(
             "/download_debug/{run_id:str}",
             self.download_debug_file_endpoint,
         ),
         # index page for every path not matched above
         Route("/{path:path}", self.index_html_endpoint),
         Route("/", self.index_html_endpoint),
     ]
     return api_routes + static_routes + fallback_routes
예제 #4
0
def routes(ws: WebsocketManager) -> List[Any]:
    """Return the single websocket route that hands connections to ``ws``.

    Each Starlette websocket on ``/{room_id}`` is wrapped in
    ``StarletteWebsocketClient`` and passed to ``ws.connection_handler``.
    """
    room_route = WebSocketRoute(
        path='/{room_id}',
        endpoint=lambda websocket: ws.connection_handler(
            StarletteWebsocketClient(websocket)),
    )
    return [room_route]
예제 #5
0
def create_app():
    """Build the app: ``GET /`` page, ``/game`` websocket, session middleware."""
    app_routes = [
        Route('/', index, methods=['GET']),
        WebSocketRoute('/game', websocket_route),
    ]
    # NOTE(review): hard-coded secret key; fine for tests, not for production.
    session = Middleware(SessionMiddleware, secret_key='testingsecretkey')
    return Starlette(routes=app_routes, middleware=[session])
예제 #6
0
파일: app.py 프로젝트: storborg/remotedroid
    def __init__(self, name, serial=None, debug=False):
        """Set up the app's routes and schedule the screenshot task on startup.

        Args:
            name: stored as ``self.name``.
            serial: stored as ``self.serial``.
            debug: forwarded to ``Starlette.__init__``.
        """
        self.name = name
        self.serial = serial
        self.templates = Jinja2Templates(directory=templates_dir)
        self.screenshot_queues = set()
        # Strong reference to the background screenshot task (set on startup).
        self._screenshot_future = None

        routes = [
            Route("/", self.index_route),
            WebSocketRoute("/ws/screenshots", self.screenshot_endpoint),
            WebSocketRoute("/ws/control", self.control_endpoint),
            Mount("/static", StaticFiles(directory=static_dir), name="static"),
        ]

        Starlette.__init__(self, debug=debug, routes=routes)

        async def startup_handler():
            # The event loop keeps only weak references to tasks, so the
            # result must be stored or the task may be garbage-collected
            # before it finishes.
            self._screenshot_future = asyncio.ensure_future(
                self.screenshot_task())

        self.add_event_handler("startup", startup_handler)
예제 #7
0
    def __init__(
        self,
        *,
        engine: Engine = None,
        sdl: str = None,
        graphiql: typing.Union[bool, GraphiQL] = True,
        path: str = "/",
        subscriptions: typing.Union[None, bool, Subscriptions] = None,
        context: dict = None,
        schema_name: str = "default",
    ):
        """Build the GraphQL ASGI app with optional GraphiQL and subscriptions.

        Args:
            engine: engine to execute queries; built from ``sdl`` when omitted.
            sdl: schema definition, required when ``engine`` is not given.
            graphiql: ``True`` for a default ``GraphiQL``, or a ``GraphiQL``
                instance. A falsy value, or an instance whose ``path`` is
                ``None``, adds no GraphiQL route.
            path: path of the GraphQL HTTP endpoint.
            subscriptions: ``True`` for subscriptions at ``/subscriptions``,
                or a ``Subscriptions`` instance. ``None`` or ``False``
                disables them.
            context: passed to ``GraphQLConfig``; defaults to an empty dict.
            schema_name: used when building the engine from ``sdl``.
        """
        if engine is None:
            assert sdl, "`sdl` expected if `engine` not given"
            engine = Engine(sdl=sdl, schema_name=schema_name)

        assert engine, "`engine` expected if `sdl` not given"

        self.engine = engine

        if context is None:
            context = {}

        if graphiql is True:
            graphiql = GraphiQL()

        if subscriptions is True:
            subscriptions = Subscriptions(path="/subscriptions")
        elif not subscriptions:
            # Treat False (or any falsy value) as "disabled"; otherwise the
            # `is not None` check below would try to read `False.path`.
            subscriptions = None

        routes = []

        if graphiql and graphiql.path is not None:
            routes.append(Route(path=graphiql.path, endpoint=GraphiQLEndpoint))

        graphql_route = Route(path=path, endpoint=GraphQLEndpoint)
        routes.append(graphql_route)

        if subscriptions is not None:
            subscription_route = WebSocketRoute(path=subscriptions.path,
                                                endpoint=SubscriptionEndpoint)
            routes.append(subscription_route)

        config = GraphQLConfig(
            engine=self.engine,
            context=context,
            graphiql=graphiql,
            path=graphql_route.path,
            subscriptions=subscriptions,
        )

        router = Router(routes=routes)
        self.app = GraphQLMiddleware(router, config=config)
        self.lifespan = Lifespan(on_startup=self.startup)

        self._started_up = False
예제 #8
0
def build_tasks_api(manager):
    """Create the tasks sub-application.

    ``GET /`` is always served by ``tasks``. The ``/{task}`` websocket route
    (named ``task_manager_api``) is registered only when the manager's task
    worker is enabled. The manager's broadcast connects on startup and
    disconnects on shutdown.
    """
    http_routes = [Route("/", endpoint=tasks, methods=["GET"])]
    socket_routes = (
        [WebSocketRoute("/{task}", TaskEndpoint, name="task_manager_api")]
        if manager._task_worker_enabled
        else []
    )
    return Starlette(
        routes=http_routes + socket_routes,
        on_startup=[manager.broadcast.connect],
        on_shutdown=[manager.broadcast.disconnect],
    )
예제 #9
0
def create_app(
    process_context: WorkspaceProcessContext,
    debug: bool,
    app_path_prefix: str,
):
    """Build the web application.

    It serves GraphQL over HTTP and websockets, static assets, the
    debug-file download and an index page.

    Args:
        process_context: bound into the GraphQL and download endpoints.
        debug: forwarded to ``Starlette``.
        app_path_prefix: bound into the GraphQL HTTP and index endpoints.
    """
    graphql_schema = create_schema()
    base_dir = path.dirname(__file__)

    serve_index = partial(index_endpoint, base_dir, app_path_prefix)

    graphql_routes = [
        Route("/dagit_info", dagit_info_endpoint),
        Route(
            "/graphql",
            partial(graphql_http_endpoint, graphql_schema, process_context,
                    app_path_prefix),
            name="graphql-http",
            methods=["GET", "POST"],
        ),
        WebSocketRoute(
            "/graphql",
            partial(graphql_ws_endpoint, graphql_schema, process_context),
            name="graphql-ws",
        ),
    ]
    static_routes = [
        # static resources addressed at /static/
        Mount(
            "/static",
            StaticFiles(directory=path.join(base_dir, "./webapp/build/static")),
            name="static",
        ),
        # static resources addressed at /vendor/
        Mount(
            "/vendor",
            StaticFiles(directory=path.join(base_dir, "./webapp/build/vendor")),
            name="vendor",
        ),
        # specific static resources addressed at /
        *create_root_static_endpoints(base_dir),
    ]
    fallback_routes = [
        # download file endpoints
        Route(
            "/download_debug/{run_id:str}",
            partial(download_debug_file_endpoint, process_context),
        ),
        Route("/{path:path}", serve_index),
        Route("/", serve_index),
    ]

    return Starlette(
        debug=debug,
        routes=graphql_routes + static_routes + fallback_routes,
    )
예제 #10
0
    def __init__(
        self,
        schema: GraphQLSchema = None,
        *,
        type_defs: str = None,
        schema_file: str = None,
        federation: bool = False,
        playground: bool = True,
        debug: bool = False,
        routes: typing.List[BaseRoute] = None,
        path: str = '/',
        subscription_path: str = '/',
        subscription_authenticate: typing.Awaitable = None,
        error_formater: ERROR_FORMATER = None,
        graphql_middleware: typing.Union[tuple, list,
                                         typing.Dict[str, list]] = None,
        graphql_middleware_exclude: typing.List[str] = None,
        context_builder: typing.Callable = None,
        **kwargs,
    ):
        """Create an app serving a GraphQL schema over HTTP and websockets.

        The schema source is chosen in priority order: ``schema``, then
        ``type_defs``, then ``schema_file``.

        Args:
            schema: a ready-made ``GraphQLSchema``.
            type_defs: SDL string, built with ``make_schema``.
            schema_file: SDL file path, built with ``make_schema_from_file``.
            federation: forwarded to the schema builders.
            playground: forwarded to ``ASGIApp``.
            debug: forwarded to ``ASGIApp`` and to the parent ``__init__``.
            routes: extra routes placed before the GraphQL routes. The
                caller's list is copied and never modified.
            path: path of the GraphQL HTTP route.
            subscription_path: path of the subscription websocket route.
            subscription_authenticate: forwarded to ``Subscription`` as
                ``authenticate``.
            error_formater: forwarded to ``ASGIApp``.
            graphql_middleware: forwarded to ``ASGIApp``.
            graphql_middleware_exclude: forwarded to ``ASGIApp``.
            context_builder: forwarded to ``ASGIApp``.
            **kwargs: forwarded to the parent ``__init__``.

        Raises:
            Exception: if none of ``schema``, ``type_defs`` or
                ``schema_file`` is given.
        """
        # Copy so the GraphQL routes below are not appended to the caller's
        # list (which would also accumulate duplicates if it is reused).
        routes = list(routes) if routes else []
        if schema:
            self.schema = schema
        elif type_defs:
            self.schema = make_schema(type_defs, federation=federation)
        elif schema_file:
            self.schema = make_schema_from_file(schema_file,
                                                federation=federation)
        else:
            raise Exception('Must provide type def string or file.')

        routes.extend([
            Route(
                path,
                ASGIApp(
                    self.schema,
                    debug=debug,
                    playground=playground,
                    error_formater=error_formater,
                    graphql_middleware=graphql_middleware,
                    graphql_middleware_exclude=graphql_middleware_exclude,
                    context_builder=context_builder,
                ),
            ),
            WebSocketRoute(
                subscription_path,
                Subscription(self.schema,
                             authenticate=subscription_authenticate),
            ),
        ])
        super().__init__(debug=debug, routes=routes, **kwargs)
예제 #11
0
def test_app_add_websocket_route(test_client_factory):
    """A function websocket endpoint passed via ``routes=`` greets the client."""
    greeting = "Hello, world!"

    async def websocket_endpoint(session):
        await session.accept()
        await session.send_text(greeting)
        await session.close()

    route = WebSocketRoute("/ws", endpoint=websocket_endpoint)
    client = test_client_factory(Starlette(routes=[route]))

    with client.websocket_connect("/ws") as session:
        assert session.receive_text() == greeting
예제 #12
0
def test_routes():
    """The application's route table equals the expected declaration."""
    users_router = Router(routes=[
        Route("/", endpoint=all_users_page),
        Route("/{username}", endpoint=user_page),
    ])
    expected = [
        Route("/func", endpoint=func_homepage, methods=["GET"]),
        Route("/async", endpoint=async_homepage, methods=["GET"]),
        Route("/class", endpoint=Homepage),
        Mount("/users", app=users_router),
        Route("/500", endpoint=runtime_error, methods=["GET"]),
        WebSocketRoute("/ws", endpoint=websocket_endpoint),
    ]
    assert app.routes == expected
예제 #13
0
    def build_routes(self):
        """Return the app's routes, mounted under the path prefix if set.

        Without ``self._app_path_prefix`` the route list is returned as-is.
        With a prefix, the list is mounted under it and ``/`` redirects to
        the prefix.
        """
        api_routes = [
            Route("/dagit_info", self.dagit_info_endpoint),
            Route(
                "/graphql",
                self.graphql_http_endpoint,
                name="graphql-http",
                methods=["GET", "POST"],
            ),
            WebSocketRoute(
                "/graphql",
                self.graphql_ws_endpoint,
                name="graphql-ws",
            ),
        ]
        static_routes = self.build_static_routes()
        fallback_routes = [
            # download file endpoints
            Route(
                "/download/{run_id:str}/{step_key:str}/{file_type:str}",
                self.download_compute_logs_endpoint,
            ),
            Route("/dagit/notebook", self.download_notebook),
            Route(
                "/download_debug/{run_id:str}",
                self.download_debug_file_endpoint,
            ),
            Route("/{path:path}", self.index_html_endpoint),
            Route("/", self.index_html_endpoint),
        ]
        routes = api_routes + static_routes + fallback_routes

        if not self._app_path_prefix:
            return routes

        def _redirect(_):
            return RedirectResponse(url=self._app_path_prefix)

        return [
            Mount(self._app_path_prefix, routes=routes),
            Route("/", _redirect),
        ]
예제 #14
0
    def ws_magic_route(
        self,
        path: str,
        endpoint: Callable,
        *,
        name: str = None,
    ) -> WebSocketRoute:
        """Build a Starlette ``WebSocketRoute`` whose endpoint is container-bound.

        The endpoint is wrapped by ``self.wrapped_endpoint_factory`` using
        ``self._container.magic_partial`` so dependencies can be injected
        automatically. The remaining arguments go straight to Starlette.

        :param path: URL path pattern of the websocket route.
        :param endpoint: callable to wrap and serve on ``path``.
        :param name: optional route name, forwarded to ``WebSocketRoute``.
        :return: the constructed ``WebSocketRoute``.
        """
        return WebSocketRoute(
            path,
            self.wrapped_endpoint_factory(endpoint,
                                          self._container.magic_partial),
            name=name,
        )
예제 #15
0
    def on_api_initialized_v2(self, api):
        """Mount the import-tracker sub-application on ``api``.

        The sub-application is mounted at ``/import-tracker``. It lists the
        registered pipeline names at ``/pipelines`` and serves
        ``ImporterEndpoint`` on the ``/pipeline/{pipeline}`` websocket.
        """
        def import_pipelines(req):
            # Respond with the names (keys) of self.pipelines.
            return JSONResponse(list(self.pipelines.keys()))

        # NOTE: constructing Broadcast does not connect; a connection failure
        # is not reported here ("breaks silently").
        self.broadcast = Broadcast("redis://broker:6379")
        log.debug("Setting up pipelines API")

        tracker_routes = [
            Route("/pipelines", import_pipelines),
            WebSocketRoute("/pipeline/{pipeline}",
                           ImporterEndpoint,
                           name="import_tracker_api"),
        ]
        tracker_app = Starlette(
            routes=tracker_routes,
            on_startup=[self.connect],
            on_shutdown=[self.broadcast.disconnect],
        )
        api.mount("/import-tracker", tracker_app)
예제 #16
0
def run_server(host, port, no_open, shutdown_timeout, home_url):
    """Run the PIO Home web server under uvicorn (blocks until shutdown).

    All routes are relative to ``<base>``, the path of ``home_url``
    normalized to start and end with ``/``:

    - ``<base>wsrpc``: the JSON-RPC websocket.
    - ``<base>__shutdown__``: a POST shutdown hook.
    - ``<base>``: the PIO Home Contrib static files.

    Args:
        host: interface passed to ``uvicorn.run``.
        port: port passed to ``uvicorn.run``.
        no_open: when true, do not open ``home_url`` on startup.
        shutdown_timeout: passed to ``WebSocketJSONRPCServerFactory``.
        home_url: URL of the home page; its path is the route base.

    Raises:
        PlatformioException: if the PIO Home Contrib directory is missing.
    """
    contrib_dir = get_core_package_dir("contrib-piohome")
    if not os.path.isdir(contrib_dir):
        raise PlatformioException("Invalid path to PIO Home Contrib")

    ws_rpc_factory = WebSocketJSONRPCServerFactory(shutdown_timeout)
    ws_rpc_factory.addHandler(AccountRPC(), namespace="account")
    ws_rpc_factory.addHandler(AppRPC(), namespace="app")
    ws_rpc_factory.addHandler(IDERPC(), namespace="ide")
    ws_rpc_factory.addHandler(MiscRPC(), namespace="misc")
    ws_rpc_factory.addHandler(OSRPC(), namespace="os")
    ws_rpc_factory.addHandler(PIOCoreRPC(), namespace="core")
    ws_rpc_factory.addHandler(ProjectRPC(), namespace="project")

    # Routes are built by string concatenation, so the base must start and
    # end with "/". A URL without a path (e.g. "http://127.0.0.1:8008")
    # gives "", which would produce the route "wsrpc" that Starlette rejects;
    # "/home" would produce "/homewsrpc".
    path = urlparse(home_url).path
    if not path.startswith("/"):
        path = "/" + path
    if not path.endswith("/"):
        path += "/"

    routes = [
        WebSocketRoute(path + "wsrpc", ws_rpc_factory, name="wsrpc"),
        Route(path + "__shutdown__", shutdown_server, methods=["POST"]),
        Mount(path,
              StaticFiles(directory=contrib_dir, html=True),
              name="static"),
    ]
    # When the app lives below the root, "/" is served by protected_page.
    if path != "/":
        routes.append(Route("/", protected_page))

    uvicorn.run(
        Starlette(
            middleware=[Middleware(ShutdownMiddleware)],
            routes=routes,
            on_startup=[
                lambda: click.echo(
                    "PIO Home has been started. Press Ctrl+C to shutdown."),
                lambda: None if no_open else click.launch(home_url),
            ],
        ),
        host=host,
        port=port,
        log_level="warning",
    )
예제 #17
0
def reload_endpoint(watch_path: str):
    """Return the devtools routes: a live-reload websocket and an "up" check.

    Each websocket connection runs a watcher over ``watch_path`` (``awatch``
    with ``FoxgloveWatcher``). It sends the text ``'reload'`` to the client
    every time the watcher yields.

    Args:
        watch_path: path handed to ``awatch``.

    Returns:
        Two routes:

        - the ``/.devtools/reload-ws/`` websocket (named ``devtools-reload``);
        - ``GET /.devtools/up/``, served by ``devtools_up``.
    """
    async def watch_reload(prompt_reload):
        # Await prompt_reload() once per item yielded by the watcher.
        async for _ in awatch(watch_path, watcher_cls=FoxgloveWatcher):
            await prompt_reload()

    class ReloadWs(WebSocketEndpoint):
        """Websocket endpoint that pushes 'reload' when watched files change."""
        encoding = 'text'

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # The watcher task is started at construction, i.e. before
            # on_connect runs. Until then self.ws is None, so
            # prompt_reload() does nothing.
            self._watch_task = asyncio.create_task(
                watch_reload(self.prompt_reload))
            self.ws = None

        async def prompt_reload(self):
            # Only send once a client websocket has been accepted.
            if self.ws:
                logger.debug('prompting reload')
                await self.ws.send_text('reload')

        async def on_connect(self, websocket):
            logger.debug('reload websocket connecting')
            await websocket.accept()
            self.ws = websocket

        async def on_disconnect(self, websocket, close_code):
            logger.debug('reload websocket disconnecting')
            # Stop the watcher and wait for it. Cancellation is expected and
            # swallowed; any other exception raised by the watcher task
            # propagates from the await.
            self._watch_task.cancel()
            try:
                await self._watch_task
            except asyncio.CancelledError:
                logger.debug('file watcher cancelled')

    return [
        WebSocketRoute('/.devtools/reload-ws/',
                       ReloadWs,
                       name='devtools-reload'),
        Route('/.devtools/up/', devtools_up, methods=['GET']),
    ]
예제 #18
0
    def add_ws_endpoint(self, *args, **kwargs):
        """
        Add a websocket endpoint to the server
        Args:
            *args: either a single mock ws endpoint, or parameters forwarded to ws_endpoint construct one
            **kwargs: forwarded to ws_endpoint to construct an endpoint

        Returns:
            the websocket endpoint added to the server

        Raises:
            RuntimeError: if the endpoint already belongs to a server
        """
        self._raise_from_pending()
        if len(args) == 1 and not kwargs:
            ep, = args
        else:
            ep = ws_endpoint(*args, **kwargs)

        with self._route_lock:
            # Check and claim ownership under the same lock as the route
            # registration. Otherwise two concurrent callers could both see
            # owner None and register the same endpoint twice.
            if ep.owner is not None:
                raise RuntimeError('an endpoint cannot be added twice')
            self._app.routes.append(
                WebSocketRoute(ep.rule_string, ep.endpoint, name=ep.__name__))
            ep.owner = self
        return ep
예제 #19
0
 def __init__(
         self,
         aggregation: str,
         host: str = None,
         http_port: int = None,
         exit_round: int = 1,
         participants_count: int = 1,
         ws_size: int = 10 * 1024 * 1024):
     """Configure the aggregation server and its FastAPI app.

     Args:
         aggregation: server name, and the path segment of both routes.
         host: bind address. When falsy, it is read from the ``AGG_BIND_IP``
             parameter, falling back to ``get_host_ip()``.
         http_port: bind port. When falsy, it is read from the
             ``AGG_BIND_PORT`` parameter (default 7363) as an ``int``.
         exit_round: stored as an ``int``, clamped to at least 1.
         participants_count: stored as-is.
         ws_size: forwarded to the parent constructor.
     """
     host = host or Context.get_parameters("AGG_BIND_IP", get_host_ip())
     http_port = http_port or int(
         Context.get_parameters("AGG_BIND_PORT", 7363))
     super(AggregationServer, self).__init__(
         servername=aggregation,
         host=host,
         http_port=http_port,
         ws_size=ws_size,
     )
     self.aggregation = aggregation
     self.participants_count = participants_count
     self.exit_round = max(int(exit_round), 1)

     # HTTP info route and broadcast websocket share the same path.
     info_route = APIRoute(
         f"/{aggregation}",
         self.client_info,
         response_class=JSONResponse,
     )
     broadcast_route = WebSocketRoute(f"/{aggregation}", BroadcastWs)
     self.app = FastAPI(routes=[info_route, broadcast_route])
     self.app.shutdown = False
예제 #20
0
        return is_basic_mutation(type.of_type)
    return False


def validate_schema(schema):
    """Check every mutation field of *schema* with ``is_non_null_basic_mutation``.

    The placeholder field ``_unused`` is skipped. A schema without a mutation
    type has nothing to check and passes.

    Raises:
        ValueError: naming the first mutation field whose type fails the
            check.
    """
    mutation_type = schema.mutation_type
    # A schema may define no mutations at all; previously this raised
    # AttributeError on `None.fields`.
    if mutation_type is None:
        return
    for name, mutation in mutation_type.fields.items():
        if name == "_unused":
            continue
        if not is_non_null_basic_mutation(mutation.type):
            raise ValueError(f"{name} Must implement IMutationResponse!")


# Validate at import time so a non-conforming mutation fails before the
# app is built.
validate_schema(schema)
# GraphQL app for HTTP requests, with auth_middleware applied.
graphql_route = GraphQL(schema, middleware=[auth_middleware])
app = Starlette(
    routes=[
        WebSocketRoute("/ws/shell/{uuid:str}", handle_shell),
        # Both spellings of the API path use the same HTTP GraphQL app.
        Route("/api", graphql_route),
        Route("/api/", graphql_route),
        # NOTE(review): this websocket GraphQL app is a separate instance
        # built without auth_middleware — confirm websocket operations are
        # authenticated some other way.
        WebSocketRoute("/api", GraphQL(schema=schema)),
        # Catch-all static site; must stay last, since Mount("/") matches
        # every path and routes are tried in order.
        Mount("/", StaticFilesFallback(directory="static", html=True)),
    ],
    on_startup=[on_startup],
    on_shutdown=[on_shutdown],
    middleware=[
        Middleware(RawContextMiddleware),
        Middleware(RequestCacheMiddleware)
    ],
)
예제 #21
0
def test_standalone_ws_route_does_not_match():
    """A bare WebSocketRoute used as the ASGI app rejects other paths."""
    client = TestClient(WebSocketRoute("/", ws_helloworld))
    with pytest.raises(WebSocketDisconnect), client.websocket_connect("/invalid"):
        pass  # pragma: nocover
예제 #22
0
def test_standalone_ws_route_matches():
    """A bare WebSocketRoute used as the ASGI app serves its own path."""
    client = TestClient(WebSocketRoute("/", ws_helloworld))
    with client.websocket_connect("/") as session:
        assert session.receive_text() == "Hello, world!"
예제 #23
0
    url = request.url_for("http_endpoint")
    return Response(f"URL: {url}", media_type="text/plain")


class WebSocketEndpoint:
    """Raw ASGI websocket app that reports its own reversed URL, then closes."""

    async def __call__(self, scope, receive, send):
        ws = WebSocket(scope=scope, receive=receive, send=send)
        await ws.accept()
        own_url = str(ws.url_for("websocket_endpoint"))
        await ws.send_json({"URL": own_url})
        await ws.close()


# HTTP and websocket routes share the path "/"; the router picks one by
# connection type (see test_protocol_switch).
mixed_protocol_app = Router(routes=[
    Route("/", endpoint=http_endpoint),
    WebSocketRoute(
        "/", endpoint=WebSocketEndpoint(), name="websocket_endpoint"),
])


def test_protocol_switch():
    client = TestClient(mixed_protocol_app)

    response = client.get("/")
    assert response.status_code == 200
    assert response.text == "URL: http://testserver/"

    with client.websocket_connect("/") as session:
        assert session.receive_json() == {"URL": "ws://testserver/"}

    with pytest.raises(WebSocketDisconnect):
        with client.websocket_connect("/404"):
예제 #24
0
async def background_task(message):
    """Wait five seconds (stand-in for slow work), then print *message*."""
    # Placeholder for slow work such as sending mail or hashing a password.
    delay_seconds = 5
    await sleep(delay_seconds)
    print(message)


# starlette functional endpoint
async def test(request):
    """Return a plain-text reply and run ``background_task`` after sending it."""
    body = "Testing Starlette functional endpoint with a bg task!"
    followup = BackgroundTask(background_task, message="testing background task!")
    return PlainTextResponse(body, background=followup)


# fastapi routes
# Routes are tried in order; the Mount("/") static-files catch-all must stay
# last so it does not shadow the routes above it.
routes = [
    Mount("/sanic-api", sanic_app),
    WebSocketRoute("/ws", websocket_endpoint),
    Route('/test', endpoint=test),
    Mount("/", starlette_static_files),
]

app = FastAPI(routes=routes)

# uncomment this to enable the trusted host middleware

#app.add_middleware(
#    TrustedHostMiddleware,
#    allowed_hosts=['example.com', '*.example.com']
#)
예제 #25
0
    pass  # pragma: no cover


def schema(request):
    """Endpoint that serves the generated OpenAPI schema for *request*."""
    response = schemas.OpenAPIResponse(request=request)
    return response


# Sub-application mounted at "/subapp" in the main app below.
subapp = Starlette(
    routes=[
        Route("/subapp-endpoint", endpoint=subapp_endpoint),
    ]
)

# Application whose route table is fed to schema generation in the test below.
app = Starlette(
    routes=[
        WebSocketRoute("/ws", endpoint=ws),
        # Same path, different methods: read vs. create.
        Route("/users", endpoint=list_users, methods=["GET", "HEAD"]),
        Route("/users", endpoint=create_user, methods=["POST"]),
        Route("/orgs", endpoint=OrganisationsEndpoint),
        Route("/regular-docstring-and-schema", endpoint=regular_docstring_and_schema),
        Route("/regular-docstring", endpoint=regular_docstring),
        Route("/no-docstring", endpoint=no_docstring),
        # Served, but excluded from the generated schema.
        Route("/schema", endpoint=schema, methods=["GET"], include_in_schema=False),
        Mount("/subapp", subapp),
    ]
)


def test_schema_generation():
    schema = schemas.get_schema(routes=app.routes)
    assert schema == {
예제 #26
0
</body>
<script type="text/javascript">
const ws = new WebSocket(`ws://${window.location.host}/ws`);
ws.onmessage = (e) => {
  const data = e.data;
  console.log(data);
};
ws.onclose = (e) => {
  console.error(e)
};
</script>
</html>
"""
    return HTMLResponse(html)


# Demo app: the HTML page at "/" and the `hello` websocket at "/ws".
app = Starlette(debug=True, routes=[Route("/", index), WebSocketRoute("/ws", hello)])


def test_app():
    """Open the ``/ws`` websocket on ``app`` and print the first text frame."""
    from starlette.testclient import TestClient

    client = TestClient(app)
    with client.websocket_connect("/ws") as session:
        print(session.receive_text())


if __name__ == "__main__":
    # Running the module directly performs the same smoke check.
    test_app()
예제 #27
0
from starlette.routing import Route, WebSocketRoute


async def homepage(request):
    """Serve the static test page."""
    # NOTE: relative to the process's working directory, not this file.
    page_path = './tests/index.html'
    return FileResponse(page_path)


class MyElement(CustomElement):
    """Demo element whose rendering depends on ``self.test``.

    It renders an ``h1`` wired to ``click``. Once ``click`` has set ``test``
    to ``"bar"``, it renders a ``span`` instead.
    """

    async def setup(self):
        self.test = "foo"

    async def click(self):
        self.test = "bar"

    async def render(self):
        if self.test != "bar":
            return h1(on_click=self.click)[self.test]
        return span[i["baz"], "foobar", b["barbaz"]]


async def on_accept(websocket: HTMLOverTheAirWebSocket):
    """Render ``MyElement`` into the connection's DOM."""
    dom = websocket.dom
    await dom.render(MyElement)


# Test page at "/"; the HTML-over-the-air websocket at "/ws" calls
# on_accept for each connection.
routes = [
    Route("/", endpoint=homepage),
    WebSocketRoute("/ws", endpoint=HTMLOverTheAirEndpoint(on_accept=on_accept))
]

app = Starlette(routes=routes)
예제 #28
0
class HandledExcAfterResponse:
    """ASGI-2-style app that sends a 200 "OK" response, then raises HTTP 406."""

    def __init__(self, scope):
        # The scope is not needed.
        pass

    async def __call__(self, receive, send):
        await PlainTextResponse("OK", status_code=200)(receive, send)
        # Deliberately raised after the response has been sent.
        raise HTTPException(status_code=406)


# Endpoints that raise in different ways; raise_runtime_error is served
# both over HTTP and as a websocket at "/runtime_error".
router = Router(routes=[
    Route("/runtime_error", endpoint=raise_runtime_error),
    Route("/not_acceptable", endpoint=not_acceptable),
    Route("/not_modified", endpoint=not_modified),
    Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse),
    WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
])

# The router wrapped in ExceptionMiddleware, plus a shared client used by
# the tests below.
app = ExceptionMiddleware(router)
client = TestClient(app)


def test_server_error():
    """Server errors propagate by default, or become a plain 500 when disabled."""
    with pytest.raises(RuntimeError):
        client.get("/runtime_error")

    lenient_client = TestClient(app, raise_server_exceptions=False)
    response = lenient_client.get("/runtime_error")
    assert response.status_code == 500
    assert response.text == "Internal Server Error"
예제 #29
0
    def __init__(
        self,
        *,
        engine: Engine = None,
        sdl: str = None,
        graphiql: typing.Union[None, bool, GraphiQL] = True,
        path: str = "/",
        subscriptions: typing.Union[bool, Subscriptions] = None,
        context: dict = None,
        schema_name: str = "default",
    ) -> None:
        """Build the GraphQL ASGI app with optional GraphiQL and subscriptions.

        Args:
            engine: engine to execute queries; built from ``sdl`` when omitted.
            sdl: schema definition, required when ``engine`` is not given.
            graphiql: ``True`` for a default ``GraphiQL``, an instance to
                customize it, or a falsy value (normalized to ``None``) to
                disable it.
            path: path of the GraphQL HTTP endpoint.
            subscriptions: ``True`` for subscriptions at ``/subscriptions``,
                an instance to customize them, or a falsy value (normalized
                to ``None``) to disable them.
            context: passed to ``GraphQLConfig``; defaults to an empty dict.
            schema_name: used when building the engine from ``sdl``.
        """
        if engine is None:
            assert sdl, "`sdl` expected if `engine` not given"
            engine = Engine(sdl=sdl, schema_name=schema_name)
        assert engine, "`engine` expected if `sdl` not given"
        self.engine = engine

        context = {} if context is None else context

        # True -> default instance; any falsy value -> None; otherwise keep.
        graphiql = GraphiQL() if graphiql is True else (graphiql or None)
        assert graphiql is None or isinstance(graphiql, GraphiQL)

        subscriptions = (
            Subscriptions(path="/subscriptions")
            if subscriptions is True
            else (subscriptions or None)
        )
        assert subscriptions is None or isinstance(subscriptions,
                                                   Subscriptions)

        routes: typing.List[BaseRoute] = [
            *(
                [Route(graphiql.path, GraphiQLEndpoint)]
                if graphiql and graphiql.path is not None
                else []
            ),
            Route(path, GraphQLEndpoint),
            *(
                [WebSocketRoute(subscriptions.path, SubscriptionEndpoint)]
                if subscriptions is not None
                else []
            ),
        ]

        self.router = Router(routes=routes, on_startup=[self.startup])

        self.app = GraphQLMiddleware(
            self.router,
            config=GraphQLConfig(
                engine=self.engine,
                context=context,
                graphiql=graphiql,
                path=path,
                subscriptions=subscriptions,
            ),
        )

        self._started_up = False
예제 #30
0
from starlette.applications import Starlette
from starlette.routing import Route, WebSocketRoute
from starlette.middleware import Middleware
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
from sqlalchemy import update
import config
from views import Server
from models import Song


ROUTES = [
    WebSocketRoute('/ws', Server)
]


MIDDLEWARES = [
    Middleware(CORSMiddleware, allow_origins=config.CORS_ORIGIN),
    Middleware(GZipMiddleware),
]

# Outside debug mode, also force HTTPS and restrict accepted Host headers.
if not config.DEBUG:
    MIDDLEWARES.append(Middleware(HTTPSRedirectMiddleware))
    MIDDLEWARES.append(Middleware(TrustedHostMiddleware, allowed_hosts=config.ALLOWED_HOSTS))


# The middleware stack must be passed to Starlette explicitly; previously
# MIDDLEWARES was built but never used, so none of it was applied.
app = Starlette(debug=config.DEBUG, routes=ROUTES, middleware=MIDDLEWARES)
app.state.config = config