Ejemplo n.º 1
0
def init():
    """Build and return the configured Starlette application.

    Registers, in this order: the startup hook that awaits ``pg_init``, the
    JSON-decode and permission exception handlers, the Google authentication
    and session middleware, Sentry, the API routes, the index page and the
    static file mount.
    """
    application = Starlette()

    @application.on_event("startup")
    async def async_setup():
        await pg_init()

    @application.exception_handler(JSONDecodeError)
    async def bad_json(request, exc):
        payload = {'reason': 'invalid json', 'details': str(exc)}
        return JSONResponse(payload, status_code=400)

    @application.exception_handler(InsufficientPermissionsError)
    async def handle_permissions(request, exc):
        payload = {'reason': 'you are not authorized to do that dave'}
        return JSONResponse(payload, status_code=403)

    # Authentication: the Google backend also supplies the error callback.
    google_backend = GoogleAuthBackend(GOOGLE_ID, GOOGLE_SECRET, GOOGLE_ORG)
    application.add_middleware(
        AuthenticationMiddleware,
        backend=google_backend,
        on_error=google_backend.on_error,
    )

    # Session cookie lives for two days; HTTPS-only unless running locally.
    two_days_in_seconds = 2 * 24 * 60 * 60
    application.add_middleware(
        SessionMiddleware,
        session_cookie=COOKIE_NAME,
        secret_key=COOKIE_KEY,
        https_only=not LOCAL_DEV,
        max_age=two_days_in_seconds,
    )

    # Sentry error reporting.
    sentry_sdk.init(dsn=SENTRY_URL, environment=ENV_NAME)
    application.add_middleware(SentryMiddleware)

    async def index_html(request):
        index_page = pathlib.Path('tmeister/static/index.html')
        return HTMLResponse(index_page.read_text())

    # (path, endpoint, extra keyword arguments) in registration order.
    route_table = (
        ('/api/envs/{name}/toggles', toggles.get_toggle_states_for_env, {'methods': ['GET']}),
        ('/api/toggles', toggles.get_all_toggle_states, {'methods': ['GET']}),
        ('/api/toggles', toggles.set_toggle_state, {'methods': ['PATCH']}),
        ('/api/features', features.get_features, {}),
        ('/api/features', features.create_feature, {'methods': ['POST']}),
        ('/api/features/{name}', features.delete_feature, {'methods': ['DELETE']}),
        ('/api/envs', environments.get_envs, {}),
        ('/api/envs', environments.add_env, {'methods': ['POST']}),
        ('/api/envs/{name}', environments.delete_env, {'methods': ['DELETE']}),
        ('/api/auditlog', auditing.get_audit_events, {}),
        ('/heartbeat', health.get_health, {}),
        ('/', index_html, {}),
    )
    for route_path, endpoint, options in route_table:
        application.add_route(route_path, endpoint, **options)

    application.mount('/', app=StaticFiles(directory='tmeister/static'), name='static')

    return application
Ejemplo n.º 2
0

@pytest.fixture(autouse=True)
async def create_database():
    """Create all tables of ``db`` before the test and drop them afterwards.

    The engine is closed in the ``finally`` branch, so cleanup also runs when
    table creation or the test itself fails.
    """
    test_engine = await gino.create_engine('postgresql://*****:*****@localhost/test')
    try:
        print("creating Database")
        await db.gino.create_all(test_engine)
        yield
    finally:
        await db.gino.drop_all(test_engine)
        await test_engine.close()


# Application under test. `db` and `database_url` are defined outside this
# excerpt and handed to DatabaseMiddleware unchanged.
app = Starlette()
app.add_middleware(DatabaseMiddleware, db=db, database_url=database_url)


@app.route('/users', methods=['GET', 'POST'])
async def users(request):
    """List all users, or create one from the JSON request body.

    POST: the decoded JSON body is passed as keyword arguments to
    ``User.create`` and a confirmation message is returned.

    Any other method routed here: every user is returned as a list of
    ``{'id': ..., 'name': ...}`` dicts. Listing is the fallback branch rather
    than an explicit ``== 'GET'`` test so the handler always returns a
    response: Starlette also routes HEAD requests to routes that allow GET,
    and a handler that falls through and returns ``None`` ends in a server
    error.
    """
    if request.method == 'POST':
        # NOTE(review): the body is forwarded to the model unvalidated;
        # confirm that callers are trusted or add validation upstream.
        body = await request.json()
        await User.create(**body)
        return JSONResponse({'status': 'success', 'message': 'Data Saved'})

    all_users = await User.query.gino.all()
    listing = [{'id': user.id, 'name': user.name} for user in all_users]
    return JSONResponse({'status': 'success', 'message': listing})
Ejemplo n.º 3
0
def create_app():
    """Return a Starlette app with the two session routes and DatabaseMiddleware."""
    application = Starlette()
    session_routes = (
        ("/session_initialized", session_initialized),
        ("/session_not_initialized", session_not_initialized),
    )
    for route_path, endpoint in session_routes:
        application.add_route(route_path, endpoint)
    application.add_middleware(DatabaseMiddleware)
    return application
Ejemplo n.º 4
0
            elapsed_time = time.perf_counter() - start
        except Exception as e:
            metric_provider.counter("server.call.exception.counter", tags=tags)
            raise e from None
        else:
            tags.update({"status_code": response.status_code})
            metric_provider.timer("server.call.elapsed",
                                  value=elapsed_time,
                                  tags=tags)
            metric_provider.counter("server.call.counter", tags=tags)

        return response


# Middleware that logs exceptions to Sentry.
app.add_middleware(SentryMiddleware)

# Middleware that captures metrics through Dispatch's metrics provider.
app.add_middleware(MetricsMiddleware)

# Install all plugins.
install_plugins()

# Register the plugin event API routes on the API router; this runs before
# the router is included below.
install_plugin_events(api_router)

# Expose every API route on the web API framework under the /v1 prefix.
api.include_router(api_router, prefix="/v1")

# we mount the frontend and app
if STATIC_DIR:
from starlette.responses import PlainTextResponse
from starlette.middleware.cors import CORSMiddleware

import base64
import ssl

# load learner for fast.ai
from fastai import *
from fastai.vision import *
defaults.device = torch.device('cpu')

from io import BytesIO
import uvicorn
app = Starlette()
# CORS is opened to every origin. Use this only when the JSON server is a
# different server than the HTTP server.
app.add_middleware(CORSMiddleware, allow_origins=['*'])

# Endpoint to check whether the server is working


@app.route("/ping", methods=["get"])
async def ping(request):
    """Answer with a fixed JSON payload so callers can tell the server is up."""
    status_payload = {"isWorking": "Yes"}
    return JSONResponse(status_payload)


@app.route("/movies_cold", methods=["POST"])
async def movies_cold(request):
    """Return ``m.movies.sample(20)`` serialised with ``to_json(orient='records')``.

    The JSON string is sent as a plain-text response.
    """
    sampled = m.movies.sample(20)
    records_json = sampled.to_json(orient='records')
    return PlainTextResponse(records_json)


@app.route("/movies_get", methods=["POST"])
Ejemplo n.º 6
0
    """
    template = "404.html"
    context = get_request_context(request)
    return templates.TemplateResponse(template, context, status_code=404)


@app.exception_handler(500)
async def server_error(request, exc):
    """
    Render the ``500.html`` template with an HTTP 500 status.
    """
    return templates.TemplateResponse(
        "500.html",
        get_request_context(request),
        status_code=500,
    )


# S3 settings come from the environment; an unset variable yields None.
BUCKET_NAME = os.environ.get("BUCKET_NAME")
REGION_NAME = os.environ.get("REGION_NAME")
S3_AWS_ACCESS_KEY_ID = os.environ.get("S3_AWS_ACCESS_KEY_ID")
S3_AWS_SECRET_ACCESS_KEY = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")

app.add_middleware(
    S3StorageMiddleware,
    bucket_name=BUCKET_NAME,
    region_name=REGION_NAME,
    aws_access_key_id=S3_AWS_ACCESS_KEY_ID,
    aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY,
    static_dir="static",
)

# Mangum wraps the application; the result is exposed as `handler`.
handler = Mangum(app)
Ejemplo n.º 7
0
from starlette.applications import Starlette
from starlette.graphql import GraphQLApp
from starlette.middleware.authentication import AuthenticationMiddleware

from graphql.execution.executors.asyncio import AsyncioExecutor

from .models import schema
from .utils.authentication import JWTAuthenticationBackend
from config.settings import (
    DEBUG,
    DATABASE_URL,
    SECRET_KEY,
    JWT_ALGORITHM,
    DatabaseMiddleware,
)

# Application setup: the debug flag comes from settings.
app = Starlette()
app.debug = DEBUG
# Authentication uses the JWT backend built from SECRET_KEY and JWT_ALGORITHM.
app.add_middleware(
    AuthenticationMiddleware,
    backend=JWTAuthenticationBackend(SECRET_KEY, JWT_ALGORITHM),
)
app.add_middleware(DatabaseMiddleware, database_url=DATABASE_URL)
# The GraphQL schema is served at /query with the asyncio executor.
app.add_route("/query", GraphQLApp(schema=schema, executor=AsyncioExecutor()))
Ejemplo n.º 8
0
from starlette import status
from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette_prometheus import PrometheusMiddleware, metrics
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware

from src import config, store, task

logger = logging.getLogger(__name__)

# PrometheusMiddleware is installed and `metrics` is exposed at /metrics.
app = Starlette()
app.add_middleware(PrometheusMiddleware)
app.add_route("/metrics", metrics)

if config.DEBUG:
    logger.setLevel(logging.DEBUG)
else:
    # Outside debug mode, add ProxyHeadersMiddleware and only accept the
    # hosts listed in config.LOCAL_NETWORKS.
    app.add_middleware(ProxyHeadersMiddleware)
    app.add_middleware(TrustedHostMiddleware,
                       allowed_hosts=config.LOCAL_NETWORKS)

# CORS is configured in both modes, with config.LOCAL_NETWORKS as origins.
app.add_middleware(CORSMiddleware, allow_origins=config.LOCAL_NETWORKS)


@app.exception_handler(status.HTTP_403_FORBIDDEN)
@app.exception_handler(status.HTTP_404_NOT_FOUND)
@app.exception_handler(status.HTTP_500_INTERNAL_SERVER_ERROR)
Ejemplo n.º 9
0
        if "Authorization" not in request.headers:
            return None

        auth = request.headers["Authorization"]
        try:
            scheme, credentials = auth.split()
            decoded = base64.b64decode(credentials).decode("ascii")
        except (ValueError, UnicodeDecodeError, binascii.Error) as exc:
            raise AuthenticationError("Invalid basic auth credentials")

        username, _, password = decoded.partition(":")
        return AuthCredentials(["authenticated"]), SimpleUser(username)


# Application whose requests are authenticated by the BasicAuth backend.
app = Starlette()
app.add_middleware(AuthenticationMiddleware, backend=BasicAuth())


@app.route("/")
def homepage(request):
    """Report whether the request's user is authenticated, plus its display name."""
    current_user = request.user
    body = {
        "authenticated": current_user.is_authenticated,
        "user": current_user.display_name,
    }
    return JSONResponse(body)


@app.route("/dashboard")
@requires("authenticated")
async def dashboard(request):
Ejemplo n.º 10
0
    tasks.add_task(sync_bg_task)
    return PlainTextResponse("Hello, world!", background=tasks)


async def async_bg_task():
    """No-op coroutine: does nothing and returns ``None``."""
    return None


def sync_bg_task():
    """No-op function: does nothing and returns ``None``."""
    return None


# Routes shared by every target application below.
routes = [
    Route("/async", run_async_bg_task),
    Route("/sync", run_sync_bg_task),
]

# Generating target applications: one AsgiTest wrapper per middleware
# flavour, keyed "asgi", "basehttp" and "none" (no middleware added).
target_application = {}

app = Starlette(routes=routes)
app.add_middleware(ASGIStyleMiddleware)
target_application["asgi"] = AsgiTest(app)

app = Starlette(routes=routes)
app.add_middleware(BaseHTTPStyleMiddleware)
target_application["basehttp"] = AsgiTest(app)

# After this point `app` refers to the middleware-free application.
app = Starlette(routes=routes)
target_application["none"] = AsgiTest(app)
Ejemplo n.º 11
0
from starlette.staticfiles import StaticFiles
from starlette.responses import HTMLResponse
from starlette.templating import Jinja2Templates
from starlette.middleware.gzip import GZipMiddleware
from hello_v1.endpoints import hello_api as hello_v1
from hello_v2.endpoints import hello_api as hello_v2

import uvicorn

import logging
logger = logging.getLogger(__name__)

# Templates are loaded from the `templates` directory.
templates = Jinja2Templates(directory='templates')

# NOTE(review): debug is hard-coded to True; confirm before deploying.
app = Starlette(debug=True)
# GZip is applied with minimum_size=1000.
app.add_middleware(GZipMiddleware, minimum_size=1000)

# Static files plus the two versioned hello APIs as sub-applications.
app.mount('/static', StaticFiles(directory='statics'), name='static')
app.mount('/hello/v1', app=hello_v1)
app.mount('/hello/v2', app=hello_v2)


@app.route('/')
async def homepage(request):
    """Render ``index.html`` with the request as its only context entry."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.route('/error')
async def error(request):
Ejemplo n.º 12
0
# -*- coding: utf-8 -*-
from unittest.mock import MagicMock, call, patch

from starlette.applications import Starlette
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import PlainTextResponse
from starlette.testclient import TestClient

from keycloak.constants import GrantTypes
from keycloak.extensions.starlette import AuthenticationMiddleware
from keycloak.utils import auth_header

# Test application: Keycloak AuthenticationMiddleware with fixed callback and
# redirect targets, plus SessionMiddleware with a hard-coded test key.
app = Starlette()
app.add_middleware(
    AuthenticationMiddleware,
    callback_url="http://testserver/kc/callback",
    login_redirect_uri="/howdy",
    logout_redirect_uri="/logout",
)
app.add_middleware(SessionMiddleware, secret_key="key0123456789")


@app.route("/howdy")
async def howdy(request):
    """Return the fixed plain-text greeting."""
    greeting = "Howdy!"
    return PlainTextResponse(greeting)


@app.route("/logout")
async def logout(request):
    """Return the fixed plain-text logout confirmation."""
    confirmation = "Logged out!"
    return PlainTextResponse(confirmation)

Ejemplo n.º 13
0
    def app(self):
        """Return a new Starlette app with PrometheusMiddleware and a ``/metrics/`` route.

        The middleware is created with ``filter_unhandled_paths=True``.
        """
        application = Starlette()
        application.add_middleware(
            PrometheusMiddleware,
            filter_unhandled_paths=True,
        )
        application.add_route("/metrics/", metrics)
        return application
Ejemplo n.º 14
0
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
from starlette.routing import Route

# Download URL and local file name of the export.
export_file_url = 'https://drive.google.com/uc?id=190SxQMkQO-7HX46Pw7URQKZrBrQTqL8v&export=download'  #'https://drive.google.com/uc?export=download&id=1U6vmC0eY_ejOvFvHIjXUsvI7Jsn31SRd'
export_file_name = 'export.pkl'

classes = ['apple', 'banana', 'strawberry']
# Directory that contains this file.
path = Path(__file__).parent

templates = Jinja2Templates(directory='app/templates')

app = Starlette()
# allow_credentials is a boolean flag. The previous value ['*'] only worked
# because a non-empty list is truthy.
# NOTE(review): wildcard origins together with credentials is very permissive;
# confirm this is intended.
app.add_middleware(CORSMiddleware,
                   allow_origins=['*'],
                   allow_headers=['*'],
                   allow_methods=['*'],
                   allow_credentials=True)
# app.mount('/static', StaticFiles(directory='app/static'))
# NOTE(review): this serves the raw template sources as static files.
app.mount('/templates', StaticFiles(directory='app/templates'))


async def download_file(url, dest):
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    Args:
        url: Address fetched with an HTTP GET.
        dest: Target path object; it must provide ``exists()`` and be
            accepted by ``open``.

    Raises:
        aiohttp.ClientResponseError: If the server answers with an HTTP
            error status.
    """
    if dest.exists():
        return
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            # Fail loudly on HTTP errors: otherwise the error page would be
            # saved as ``dest`` and every later call would skip the download.
            response.raise_for_status()
            data = await response.read()
    # The HTTP connection is released before the file is written.
    with open(dest, 'wb') as f:
        f.write(data)

Ejemplo n.º 15
0
except KeyError:  # pragma: no cover
    pytest.skip("DATABASE_URL is not set", allow_module_level=True)

metadata = sqlalchemy.MetaData()

# Table used by the tests: integer primary key, text and a completed flag.
notes = sqlalchemy.Table(
    "notes",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("text", sqlalchemy.String),
    sqlalchemy.Column("completed", sqlalchemy.Boolean),
)

# Application under test; the middleware is created with
# rollback_on_shutdown=True.
app = Starlette()
app.add_middleware(
    DatabaseMiddleware, database_url=DATABASE_URL, rollback_on_shutdown=True
)


@pytest.fixture(autouse=True, scope="module")
def create_test_database():
    """Create the tables of ``metadata`` once per module; drop ``notes`` on teardown."""
    test_engine = sqlalchemy.create_engine(DATABASE_URL)
    metadata.create_all(test_engine)
    yield
    test_engine.execute("DROP TABLE notes")


@app.route("/notes", methods=["GET"])
async def list_notes(request):
    query = notes.select()
    results = await request.database.fetchall(query)
Ejemplo n.º 16
0
def setup_middleware(app: Starlette) -> None:
    """Install ``AuthMiddleware`` on ``app``."""
    app.add_middleware(AuthMiddleware)
Ejemplo n.º 17
0
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient


class CustomMiddleware(BaseHTTPMiddleware):
    """Middleware that sets ``Custom-Header: Example`` on every response."""

    async def dispatch(self, request, call_next):
        """Run the downstream handler, then add the custom header to its response."""
        downstream_response = await call_next(request)
        downstream_response.headers["Custom-Header"] = "Example"
        return downstream_response


# Application under test with CustomMiddleware installed.
app = Starlette()
app.add_middleware(CustomMiddleware)


@app.route("/")
def homepage(request):
    """Return the fixed plain-text homepage body."""
    body = "Homepage"
    return PlainTextResponse(body)


@app.route("/exc")
def exc(request):
    """Always raise a bare ``Exception``."""
    raise Exception()


@app.route("/no-response")
class NoResponse:
    def __init__(self, scope, receive, send):
Ejemplo n.º 18
0
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import HTMLResponse, JSONResponse
from starlette.staticfiles import StaticFiles

# Download URL and local file name of the export.
export_file_url = (
    "https://drive.google.com/uc?export=download&id=1-1qWJ8qX_eRZfap2tC4RDVbuijB39oDS"
)
export_file_name = "export.pkl"

classes = ["bracelet", "earring", "necklace", "ring"]
# Directory that contains this file.
path = Path(__file__).parent

# CORS allows every origin but only the two listed request headers.
app = Starlette()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_headers=["X-Requested-With", "Content-Type"],
)
app.mount("/static", StaticFiles(directory="app/static"))


class SaveFeatures:
    features = None

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)
        self.features = None

    def hook_fn(self, module, input, output):
        out = output.detach().cpu().numpy()
        if isinstance(self.features, type(None)):
Ejemplo n.º 19
0
from prometheus_client import Counter

# Counter of failed calls, labelled by the callee ("to").
ERROR_COUNT = Counter(
    "failed_call", "Counts of calls that failed", ("to",))


async def main(request: Request) -> JSONResponse:
    """Relay ``value`` from the middle service, remembering the last one seen.

    On an ``httpx.TransportError`` the failure is counted and logged, and the
    response carries the error class name together with the last stored
    value. The HTTP status is 200 in both cases.
    """
    async with httpx.AsyncClient(timeout=1.0) as client:
        try:
            upstream = await client.get("http://middle:8001")
            value = upstream.json().get('value')
            request.app.last_value = value
            return JSONResponse({"value": value})
        except httpx.TransportError as err:
            ERROR_COUNT.labels("middle").inc()
            logging.exception("Failed to talk with the middle service")
            fallback = {
                "error": type(err).__name__,
                "value": request.app.last_value,
            }
            return JSONResponse(fallback, status_code=200)


# NOTE(review): debug is hard-coded to True; confirm before deploying.
app = Starlette(debug=True, routes=[
    Route('/', main),
])
# Initial value that `main` reports until the first successful upstream call.
app.last_value = 0
app.add_middleware(PrometheusMiddleware)
app.add_route("/metrics", handle_metrics)
Ejemplo n.º 20
0
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse

import uvicorn
from starlette_context import context
from starlette_context.middleware import ContextMiddleware

# NOTE(review): debug is hard-coded to True; confirm before deploying.
app = Starlette(debug=True)


@app.route("/")
async def index(request: Request):
    """Set ``view`` to True in the context and return ``context.data`` as JSON."""
    context["view"] = True
    return JSONResponse(context.data)


class ContextFromMiddleware(ContextMiddleware):
    """Context middleware whose ``set_context`` yields a fixed mapping."""

    async def set_context(self, request: Request) -> dict:
        """Return ``{"middleware": True}``; the request is not inspected."""
        fixed_context = {"middleware": True}
        return fixed_context


app.add_middleware(ContextFromMiddleware)

# NOTE(review): the server starts at import time and binds to all interfaces;
# consider an `if __name__ == "__main__":` guard if this module is imported.
uvicorn.run(app, host="0.0.0.0")
Ejemplo n.º 21
0
@router.route("/post", methods=["POST"])
async def post(request: HTTPConnection):
    """Append the submitted ``(name, message)`` pair to the session and re-render.

    The ``posts`` list is created in the session on first use. The page is
    rendered from ``main.j2`` with the form, all posts and the URL of this
    endpoint.
    """
    submitted = PostForm(await request.form())

    author = submitted.name.data
    text = submitted.message.data

    session = request.session
    posts = session.get("posts")
    if posts is None:
        posts = session["posts"] = []

    posts.append((author, text))

    template_context = {
        "request": request,
        "form": submitted,
        "posts": posts,
        "post_create": request.url_for("post"),
    }
    return templates.TemplateResponse("main.j2", template_context)


# The router is mounted under /ssti; sessions use a hard-coded key.
app = Starlette(routes=[Mount("/ssti", app=router)])

app.add_middleware(SessionMiddleware, secret_key="doesntmatter")
Ejemplo n.º 22
0
async def index_html(request):
    """Log the page hit and render ``index.html`` with the request in its context."""
    logging.info('index page request')
    template_context = {'request': request}
    return html_templates.TemplateResponse('index.html', template_context)


routes = [
    Route('/', index_html, methods=['GET'], name='homepage'),
]

app = Starlette(debug=False, routes=routes)

# Timings are reported through PrintTimings; metric names are derived from
# the Starlette scope with an empty prefix.
app.add_middleware(
    TimingMiddleware,
    client=PrintTimings(),
    metric_namer=StarletteScopeToName(prefix='', starlette_app=app),
)

# Root logger setup: every existing handler is replaced by a single stream
# handler with the format below, at INFO level.
app_log = logging.StreamHandler()
formatter = logging.Formatter(
    '%(asctime)s %(process)s %(levelname)s %(name)s %(message)s'
)  # noqa: WPS323
app_log.setFormatter(formatter)

logger = logging.getLogger()
logger.handlers = []
logger.addHandler(app_log)
logger.setLevel(logging.INFO)
Ejemplo n.º 23
0
import os

from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint
from starlette.exceptions import HTTPException
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.testclient import TestClient

app = Starlette()

# Only "testserver" and hosts matching "*.example.org" are allowed.
app.add_middleware(TrustedHostMiddleware,
                   allowed_hosts=["testserver", "*.example.org"])


@app.exception_handler(500)
async def error_500(request, _exc):
    """Answer with HTTP 500 and a fixed JSON ``detail`` message."""
    body = {"detail": "Server Error"}
    return JSONResponse(body, status_code=500)


@app.exception_handler(405)
async def method_not_allowed(request, _exc):
    """Answer with HTTP 405 and a fixed JSON ``detail`` message."""
    body = {"detail": "Custom message"}
    return JSONResponse(body, status_code=405)


@app.exception_handler(HTTPException)
async def http_exception(request, exc):
    """Mirror the exception's ``detail`` and ``status_code`` in a JSON response."""
    detail, status = exc.detail, exc.status_code
    return JSONResponse({"detail": detail}, status_code=status)
Ejemplo n.º 24
0
#       The app will not start with this value, forcing the users to set their own secret
#       key. Therefore, the value is used as default here as well.
SECRET_KEY = webgui_config('SECRET_KEY',
                           cast=Secret,
                           default="PutSomethingRandomHere")
WEBGUI_PORT = webgui_config('PORT', cast=int, default=8000)
WEBGUI_HOST = webgui_config('HOST', default='0.0.0.0')
# NOTE(review): debug defaults to True when DEBUG is not configured; confirm.
DEBUG_MODE = webgui_config('DEBUG', cast=bool, default=True)

app = Starlette(debug=DEBUG_MODE)
# Don't check the existence of the static folder because the wrong parent folder is used if the
# source code is parsed by sphinx. This would raise an exception and lead to failure of sphinx.
app.mount('/static',
          StaticFiles(directory='webinterface/statics', check_dir=False),
          name='static')
# Session-based authentication; the session cookie is named "mercure_session".
app.add_middleware(AuthenticationMiddleware, backend=SessionAuthBackend())
app.add_middleware(SessionMiddleware,
                   secret_key=SECRET_KEY,
                   session_cookie="mercure_session")
# Sub-applications for the modules and queue pages.
app.mount("/modules", modules.modules_app)
app.mount("/queue", queue.queue_app)


async def async_run(cmd):
    """Execute the given shell command in a way compatible with asyncio.

    Returns a ``(returncode, stdout, stderr)`` tuple; both output streams are
    captured through pipes and returned as produced by ``communicate()``.
    """
    pipe = asyncio.subprocess.PIPE
    process = await asyncio.create_subprocess_shell(cmd, stdout=pipe, stderr=pipe)
    captured_out, captured_err = await process.communicate()
    return process.returncode, captured_out, captured_err
Ejemplo n.º 25
0
from gino import Gino
from starlette.applications import Starlette
from starlette.responses import JSONResponse

from starlette_gino.middleware import DatabaseMiddleware

db = Gino()
# NOTE(review): debug is hard-coded to True and the database URL is inline;
# confirm this is example-only configuration.
app = Starlette(debug=True)
app.add_middleware(DatabaseMiddleware,
                   db=db,
                   database_url='postgresql://*****:*****@localhost/test')


class User(db.Model):
    """Model mapped to the ``users`` table."""

    __tablename__ = 'users'

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String)
    fullname = db.Column(db.String)

    def __str__(self):
        """Return the id and name joined as ``"<id> - <name>"``."""
        return f"{self.id} - {self.name}"


class Address(db.Model):
    """Model mapped to the ``addresses`` table."""

    __tablename__ = 'addresses'

    id = db.Column(db.Integer, primary_key=True)
    # Foreign key to users.id; the column type is passed as None.
    user_id = db.Column(None, db.ForeignKey('users.id'))
    email_address = db.Column(db.String, nullable=False)
Ejemplo n.º 26
0
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route

from starlette_session import SessionMiddleware


async def setup_session(request: Request) -> JSONResponse:
    """Store ``data = "session_data"`` in the session and echo the session back."""
    session = request.session
    session.update({"data": "session_data"})
    return JSONResponse({"session": session})


async def clear_session(request: Request) -> JSONResponse:
    """Clear the session and echo the (now cleared) session back."""
    request.session.clear()
    return JSONResponse({"session": request.session})


async def view_session(request: Request) -> JSONResponse:
    """Echo the current session back without modifying it."""
    current_session = request.session
    return JSONResponse({"session": current_session})

# One route per session operation: set up, clear and view.
routes = [
    Route("/setup_session", endpoint=setup_session),
    Route("/clear_session", endpoint=clear_session),
    Route("/view_session", endpoint=view_session),
]

# NOTE(review): debug=True and a hard-coded secret key; confirm this is
# example-only configuration.
app = Starlette(debug=True, routes=routes)
app.add_middleware(SessionMiddleware, secret_key="secret", cookie_name="cookie22")
Ejemplo n.º 27
0
import graphene
from graphql.execution.executors.asyncio import AsyncioExecutor
from starlette.applications import Starlette
from starlette.graphql import GraphQLApp
from starlette.middleware.cors import CORSMiddleware

from blenheim.schema.schema import Query, Mutations
from blenheim.schema.settings.settings import SettingsMutations

app = Starlette()
# CORS allows every origin and header, but only the POST method.
app.add_middleware(CORSMiddleware,
                   allow_origins=['*'],
                   allow_headers=['*'],
                   allow_methods=["POST"])
# The GraphQL schema (Query + Mutations) is served at the root path.
# noinspection PyTypeChecker
app.add_route(
    '/',
    GraphQLApp(schema=graphene.Schema(query=Query, mutation=Mutations),
               executor_class=AsyncioExecutor))
Ejemplo n.º 28
0
        if "Authorization" not in request.headers:
            return None

        auth = request.headers["Authorization"]
        try:
            scheme, credentials = auth.split()
            decoded = base64.b64decode(credentials).decode("ascii")
        except (ValueError, UnicodeDecodeError, binascii.Error):
            raise AuthenticationError("Invalid basic auth credentials")

        username, _, password = decoded.partition(":")
        return AuthCredentials(["authenticated"]), SimpleUser(username)


# Application whose requests are authenticated by the BasicAuth backend.
app = Starlette()
app.add_middleware(AuthenticationMiddleware, backend=BasicAuth())


@app.route("/")
def homepage(request):
    """Report whether the request's user is authenticated, plus its display name."""
    current_user = request.user
    body = {
        "authenticated": current_user.is_authenticated,
        "user": current_user.display_name,
    }
    return JSONResponse(body)


@app.route("/dashboard")
@requires("authenticated")
async def dashboard(request):
Ejemplo n.º 29
0
from starlette.staticfiles import StaticFiles
from starlette.middleware.cors import CORSMiddleware
import uvicorn, aiohttp, asyncio
from io import BytesIO

from fastai import *
from fastai.vision import *

# Download URL and base name of the model file.
model_file_url = 'https://drive.google.com/uc?export=download&id=1hruHBDzJ0lPegIkAK3PJYyxHfjFR2M68'
model_file_name = 'model'
classes = ['cheetah', 'leopard']
# Directory that contains this file.
path = Path(__file__).parent

# CORS allows every origin but only the two listed request headers.
app = Starlette()
app.add_middleware(CORSMiddleware,
                   allow_origins=['*'],
                   allow_headers=['X-Requested-With', 'Content-Type'])
app.mount('/static', StaticFiles(directory='app/static'))


async def download_file(url, dest):
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    Args:
        url: Address fetched with an HTTP GET.
        dest: Target path object; it must provide ``exists()`` and be
            accepted by ``open``.

    Raises:
        aiohttp.ClientResponseError: If the server answers with an HTTP
            error status.
    """
    if dest.exists():
        return
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            # Fail loudly on HTTP errors: otherwise the error page would be
            # saved as ``dest`` and every later call would skip the download.
            response.raise_for_status()
            data = await response.read()
    # The HTTP connection is released before the file is written.
    with open(dest, 'wb') as f:
        f.write(data)


async def setup_learner():
    await download_file(model_file_url,
Ejemplo n.º 30
0
# Read from the environment; a missing or empty value only triggers a warning.
NECRO_BASE = os.getenv('NECRO_BASE')
if not NECRO_BASE:
    log.warning('NECRO_BASE env var is not defined. Elasticsearch queries to '
                'necropolis will not work!')


def http_exception_handler(request, exc):
    """Return the exception's ``detail`` as JSON with its own status code."""
    # An HTTPException raised by our code is assumed to have been logged
    # already, so nothing is logged here.
    detail, status = exc.detail, exc.status_code
    return JSONResponse(detail, status_code=status)


def validation_exception_handler(request, exc):
    """Return the exception's ``messages`` as JSON with HTTP 400."""
    validation_messages = exc.messages
    return JSONResponse(validation_messages, status_code=400)


def misc_exception_handler(request, exc):
    """Log the exception with its traceback and answer with a generic HTTP 500."""
    log.exception(exc)
    generic_message = 'Unexpected error'
    return JSONResponse(generic_message, status_code=500)


app = Starlette(debug=False)
# All routes are mounted at the root through a Router.
app.mount('', Router(routes.routes))
# HTTP and validation errors have their own handlers; any other exception
# goes to the catch-all handler.
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(ValidationError, validation_exception_handler)
app.add_exception_handler(Exception, misc_exception_handler)
# CORS allows every origin, restricted to GET and POST.
app.add_middleware(CORSMiddleware,
                   allow_origins=['*'],
                   allow_methods=['GET', 'POST'])
Ejemplo n.º 31
0
from starlette.staticfiles import StaticFiles
from starlette.middleware.cors import CORSMiddleware
import uvicorn, aiohttp, asyncio
from io import BytesIO

from fastai.vision import ImageDataBunch, create_cnn, open_image, get_transforms, imagenet_stats, models
from fastai import Path

# Download URL and base name of the model file.
model_file_url = 'https://www.dropbox.com/s/y4kl2gv1akv7y4i/stage-2.pth?raw=1'
model_file_name = 'model'

classes = ['black', 'grizzly', 'teddys']
# Directory that contains this file.
path = Path(__file__).parent

# CORS allows every origin but only the two listed request headers.
app = Starlette()
app.add_middleware(CORSMiddleware, allow_origins=['*'], allow_headers=['X-Requested-With', 'Content-Type'])
app.mount('/static', StaticFiles(directory='app/static'))

async def download_file(url, dest):
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    Args:
        url: Address fetched with an HTTP GET.
        dest: Target path object; it must provide ``exists()`` and be
            accepted by ``open``.

    Raises:
        aiohttp.ClientResponseError: If the server answers with an HTTP
            error status.
    """
    if dest.exists():
        return
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            # Fail loudly on HTTP errors: otherwise the error page would be
            # saved as ``dest`` and every later call would skip the download.
            response.raise_for_status()
            data = await response.read()
    # The HTTP connection is released before the file is written.
    with open(dest, 'wb') as f:
        f.write(data)

async def setup_learner():
    await download_file(model_file_url, path/'models'/f'{model_file_name}.pth')
    data_bunch = ImageDataBunch.single_from_classes(path, classes,
        tfms=get_transforms(), size=224).normalize(imagenet_stats)
    learn = create_cnn(data_bunch, models.resnet34, pretrained=False)
    learn.load(model_file_name)
Ejemplo n.º 32
0
async def index(request):
    """Render ``index.html`` with the request and the fixed ``"Home page"`` text."""
    template_context = {"request": request, "results": "Home page"}
    return templates.TemplateResponse("index.html", template_context)


routes = [
    Route("/", index),
]

# NOTE(review): debug is hard-coded to True; confirm before deploying.
app = Starlette(debug=True, routes=routes)
# Static files and the accounts routes are mounted under their own prefixes.
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/accounts", accounts_routes)
# Authentication uses UserAuthentication; sessions are keyed by SECRET_KEY.
app.add_middleware(AuthenticationMiddleware, backend=UserAuthentication())
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)


# middleware for secure headers
@app.middleware("http")
async def set_secure_headers(request, call_next):
    """Let the request through, then pass the response to ``secure_headers.starlette``."""
    outgoing = await call_next(request)
    secure_headers.starlette(outgoing)
    return outgoing


@app.exception_handler(404)
async def not_found(request, exc):
    """
    Return an HTTP 404 page.