from unittest import mock

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from starlette.middleware.sessions import SessionMiddleware

from fastapi_auth.routers import get_social_router

from .utils import ACCESS_COOKIE_NAME, REFRESH_COOKIE_NAME, MockAuthBackend

app = FastAPI()

app.add_middleware(SessionMiddleware, secret_key="SECRET", max_age=10)

router = get_social_router(
    None,
    MockAuthBackend(None, None, None, None, None),
    False,
    "RU",
    "http://127.0.0.1",
    ACCESS_COOKIE_NAME,
    REFRESH_COOKIE_NAME,
    None,
    None,
    ["google", "facebook", "vk"],
    {
        "google": {
            "id": "id",
            "secret": "secret",
        },
Beispiel #2
0
# When validation fails, the response status
# is set to 422 Unprocessable Entity.
class Dog(BaseModel):
    # Pydantic model describing one dog record.
    # ``id`` defaults to None so a client may omit it; presumably the server
    # assigns one for new records (no create handler is visible here — confirm).
    id: Optional[str] = None
    # Both remaining fields are required; requests missing them fail validation.
    breed: str
    name: str


id = str(uuid.uuid4())
dogs = {}
dogs[id] = {'id': id, 'breed': 'Whippet', 'name': 'Comet'}

app = FastAPI()
app.add_middleware(CORSMiddleware,
                   allow_origins=['*'],
                   allow_credentials=True,
                   allow_methods=['*'],
                   allow_headers=['*'])


@app.get('/dog')
def get_dogs():
    """Return all stored dog records as a JSON list."""
    return [*dogs.values()]


@app.get('/dog/{id}')
def get_dog(id: str):
    """Fetch one dog record by id.

    Responds with an empty 404 when no record with that id exists.
    """
    if id not in dogs:
        return Response(status_code=status.HTTP_404_NOT_FOUND)
    return dogs[id]
Beispiel #3
0
from fastapi.middleware.cors import CORSMiddleware

import requests

from app.routers import spotify, users
from app.utils.logger import logger
from app.models import models
from app.database import engine

models.Base.metadata.create_all(bind=engine)

app = FastAPI(redoc_url=None)

app.include_router(spotify.router)
app.include_router(users.router)

app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        os.getenv('FRONTEND_URL'),
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get('/hello')
def hello():
    """Simple smoke-test endpoint returning a constant JSON payload."""
    payload = {'hello': 'sometxt'}
    return payload
Beispiel #4
0
        "x-code-samples"] = devices_samples
    openapi_schema["paths"]["/api/queries"]["get"][
        "x-code-samples"] = queries_samples

    app.openapi_schema = openapi_schema
    return app.openapi_schema


CORS_ORIGINS = params.cors_origins.copy()
if params.developer_mode:
    CORS_ORIGINS.append(URL_DEV)

# CORS Configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)

app.add_api_route(
    path="/api/info",
    endpoint=info,
    methods=["GET"],
    response_model=InfoResponse,
    response_class=JSONResponse,
    summary=params.docs.info.summary,
    description=params.docs.info.description,
    tags=[params.docs.info.title],
)

app.add_api_route(
Beispiel #5
0
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import Response

from tasks.config import ALLOWED_HOST
from tasks.database import SessionLocal
from tasks.views import router

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=[ALLOWED_HOST],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
    """Give every request its own database session and always close it.

    Handlers reach the session through ``request.state.db``. The session is
    created *before* the ``try`` block. If ``SessionLocal()`` itself fails,
    the ``finally`` clause therefore never runs against an unset
    ``request.state.db``, which would hide the real error behind an
    ``AttributeError``.

    :param request: incoming Starlette request.
    :param call_next: forwards the request to the next middleware/handler.
    :return: the downstream response; exceptions from downstream propagate.
    """
    db = SessionLocal()
    request.state.db = db
    try:
        return await call_next(request)
    finally:
        # Close the session we opened, whether the handler succeeded or raised.
        db.close()

Beispiel #6
0
from fastapi import FastAPI, Query

# for CORS
from starlette.middleware.cors import CORSMiddleware

from stamps import solutions_for_price as solve

app = FastAPI()

# for CORS
app.add_middleware(CORSMiddleware, allow_origins=['*'])


@app.get('/')
def index(
        price: int,
        stamps: str = Query(..., regex=r'^\d{1,3}(,\d{1,3})*$'),
):
    """Return the stamp combinations ``solve`` finds for ``price``.

    :param price: target postage amount.
    :param stamps: comma-separated stamp values such as ``"5,10,20"``; the
        ``Query`` regex only admits 1-3 digit integers separated by commas.
    :return: dict echoing the parsed stamp list and price, plus ``results``,
        a list of ``(stmv, solv)`` pairs taken from ``solve``'s output.
    """
    # convert from string "x,y,z" to list of ints [x, y, z]
    stamp_values = [int(s) for s in stamps.split(',')]

    # ``solve`` yields 3-tuples; the first element is not needed here.
    results = [(stmv, solv) for _, stmv, solv in solve(price, stamp_values)]

    return {'stamps': stamp_values, 'price': price, 'results': results}
    model_uri: str
    feature_rule: int
    mlflow_enable: bool
    mlflow_server: str

    class Config:
        env_file = ".env"


@lru_cache()
def get_settings():
    """Construct ``Settings`` once and hand out the same instance afterwards.

    Because the function takes no arguments, ``lru_cache`` makes it a lazy
    singleton. Settings are therefore only built (presumably read from
    ``.env``) on the first call.
    """
    settings = Settings()
    return settings


app = FastAPI()
app.add_middleware(PrometheusMiddleware)
app.add_route("/metrics", handle_metrics)


@app.get("/prediction")
async def prediction(petal_length: float,
                     petal_width: float,
                     sepal_length: float,
                     sepal_width: float,
                     settings: Settings = Depends(get_settings)):
    one_iris = np.array(
        [[petal_length, petal_width, sepal_length, sepal_width]])
    featured_iris = one_iris[:, settings.feature_rule:]

    model = None
Beispiel #8
0
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from database.postgres import PostgresDB
from views.message_views import router as message_router
from views.order_key_set_views import router as oks_router
from views.ws_views import router as ws_router

app = FastAPI()

app.add_middleware(CORSMiddleware,
                   allow_origins=['*'],
                   allow_methods=("GET", "POST", "OPTIONS"),
                   allow_headers=("*", ),
                   allow_credentials=True)


@app.on_event("startup")
async def on_startup():
    """Call ``PostgresDB.connect()`` when the application starts."""
    connect = PostgresDB.connect
    await connect()


@app.on_event("shutdown")
async def on_shutdown():
    """Call ``PostgresDB.disconnect()`` when the application stops."""
    disconnect = PostgresDB.disconnect
    await disconnect()


routers = {
    "/oks": oks_router,
    "/order_chat": message_router,
    "/ws": ws_router,
Beispiel #9
0
        return await call_next(request)


####
#  #
####

logging.basicConfig(
    filename="log.log",
    level=logging.INFO,
    format=f'%(asctime)s %(levelname)s %(name)s %(threadName)s : %(message)s')
logger = logging.getLogger(__name__)
app = FastAPI(debug=FastAPISettings.DEBUG)

app.add_middleware(CORSMiddleware,
                   allow_origins=CORSSettings.ALLOW_ORIGINS,
                   allow_methods=['*'],
                   allow_headers=['*'])

if UvicornSettings.MAX_CONTENT_SIZE:
    app.add_middleware(LimitPostContentSizeMiddleware,
                       max_upload_size=UvicornSettings.MAX_CONTENT_SIZE)


@app.on_event('startup')
async def startup():
    """Start-up hook: announce start-up, initialise the DB, then attach routes."""
    logger.info('-- STARTING UP --')
    # Also echoed to stdout, presumably because the logging config sends
    # records to a file rather than the console — confirm.
    print('-- STARTING UP --')
    # Imports are deferred until start-up (presumably to avoid import cycles or
    # import-time side effects — confirm). The DB is initialised before the
    # routes module is imported and its routes are registered.
    from database.db import initialize_db
    initialize_db()
    from resources.routes import initialize_routes
    # NOTE(review): routes are added to ``app`` during start-up rather than at
    # import time; confirm this is intended.
    initialize_routes(app)
Beispiel #10
0
from fastapi.middleware.cors import CORSMiddleware
from src import database
from src import config
from src import endpoints
from src.websocket_manager import subscribe_for_client_signup

app = FastAPI(title="API")

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=[
        "OPTIONS",
        "GET",
        "POST",
        "PUT",
    ],
    allow_headers=["*"],
)
app.include_router(endpoints.router, tags=["bonus account"])


@app.on_event("startup")
async def startup_event():
    """Initialise the database when the application starts.

    Subscribing to client sign-up events (``subscribe_for_client_signup``)
    is currently disabled.
    """
    init = database.init_db
    init()


if __name__ == "__main__":
Beispiel #11
0
    logger.debug(f"Creating OpenApi Docs with server {app.state.servers}")
    openapi_schema = get_openapi(
        title=app.title,
        description=app.description,
        servers=app.state.servers,
        version="2.0.0",
        routes=app.routes,
    )
    # openapi_schema['info']['servers']=app.state.servers
    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi

app.add_middleware(GZipMiddleware, minimum_size=1000)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(StripParametersMiddleware)
app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=900")
app.add_middleware(TotalTimeMiddleware)
app.add_middleware(GetHostMiddleware)


class OpenAQValidationResponseDetail(BaseModel):
    loc: List[str] = None
Beispiel #12
0
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware

from .api.routes import router as api_router
from .db.mongodb_utils import connect_to_mongo, close_mongo_connection

app = FastAPI(title="KoAirY API",
              docs_url="/koairy/api/docs",
              openapi_url="/koairy/api/openapi.json")

# ``allow_origins`` takes a sequence of origin strings. A bare "*" string
# only worked by accident (``"*" in "*"`` is True); a list states the intent
# and stays correct if specific origins are added later.
app.add_middleware(CORSMiddleware,
                   allow_origins=["*"],
                   allow_methods=["*"],
                   allow_headers=["*"])

# Open/close the MongoDB connection together with the application lifetime.
app.add_event_handler("startup", connect_to_mongo)
app.add_event_handler("shutdown", close_mongo_connection)

app.include_router(api_router, prefix="/koairy/api")
Beispiel #13
0
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from api.routers import router
from mangum import Mangum

app = FastAPI()

app.add_middleware(CORSMiddleware,
                   allow_origins=["*"],
                   allow_credentials=True,
                   allow_methods=["*"],
                   allow_headers=["*"],
                   expose_headers=["Access-Control-Allow-Origin"])

app.include_router(router)


@app.get("/")
def read_root():
    """Root endpoint returning a fixed greeting."""
    greeting = {"Hello": "from Ah Hao"}
    return greeting


"""
    Required for API Gateway Proxy Integration
"""


def handler(event, context):
    """AWS Lambda entry point for API Gateway proxy integration.

    Wraps the FastAPI ``app`` in a Mangum ASGI adapter and passes the
    Lambda ``event``/``context`` through to it.

    :param event: API Gateway proxy event supplied by Lambda.
    :param context: Lambda runtime context object.
    :return: the proxy response produced by Mangum. It must be returned; a
        handler that returns ``None`` gives API Gateway nothing to send back.
    """
    asgi_handler = Mangum(app)
    return asgi_handler(event, context)
Beispiel #14
0
        db.close()


# init testing DB
# database.initBase(database.SessionLocal())

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="tokens")

app = FastAPI()

# Origins allowed to make credentialed cross-origin requests:
#   * local development servers (http/https, any port), and
#   * cardmatching.ovh or any of its subdomains (http/https, any port).
# Starlette matches this pattern against the whole Origin header. Dots are
# escaped and no ``.*`` wildcards are used, so look-alike origins such as
# "https://cardmatching.ovh.evil.example" or "http://localhost.evil.example"
# cannot match. That matters because ``allow_credentials=True`` lets an
# allowed origin read authenticated responses.
app.add_middleware(
    CORSMiddleware,
    allow_origin_regex=r"https?://(localhost|([a-z0-9-]+\.)*cardmatching\.ovh)(:\d+)?",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash.

    :param plain_password: password as entered by the user.
    :param hashed_password: hash previously produced via ``pwd_context``.
    :return: True when the password matches the hash.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches


def get_password_hash(password: str) -> str:
    """Hash a plaintext password with the module's ``pwd_context`` (bcrypt).

    :param password: plaintext password to hash.
    :return: the encoded hash string, suitable for storage and for later
        checking with ``verify_password``. (The previous ``-> bool``
        annotation was wrong: a hash string is returned.)
    """
    return pwd_context.hash(password)


def authenticate_user(db: Session, username: str,
                      password: str) -> Optional[models.UserSensitive]:
Beispiel #15
0
    labeled_dataframe = pd.read_pickle(labels_pickle)

    print("Downloading train indices file from GCP...")
    train_indices_file = load_file_from_gcp('predict/train_indexes.joblib')
    train_indices = joblib.load(train_indices_file)

# Starting API server and uploading our catalogue to memory
app = FastAPI()

database = pd.read_csv(catalogue, encoding='unicode_escape')
database = database[database['FORM'] == 'painting']

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)


# To check if server is up and running
@app.get("/")
def index():
    """Health-check endpoint; a fixed greeting shows the server is up."""
    return dict(greeting="Hello world")


# Receives the file from the frontend or Telegram bot
@app.post("/uploadfile/")
def create_upload_file(file: UploadFile = File(...),
                       nsimilar: int = Form(...),
Beispiel #16
0
import os

DEV_ENV = os.getenv("DEV_ENV")
app = FastAPI(title=config.PROJECT_NAME,
              docs_url="/api/docs",
              openapi_url="/api")

app.mount("/api/uploads", StaticFiles(directory="uploads"), name="uploads")

# Go to localhost:8000/api/coverage/index.html to see coverage report
# app.mount("/api/coverage", StaticFiles(directory="htmlcov"), name="htmlcov")

# Use HTTPS in production
if not DEV_ENV:
    app.add_middleware(HTTPSRedirectMiddleware)
    # Browsers always send the Origin header with a scheme ("https://host"),
    # and allow_origins entries are compared literally. Every entry therefore
    # needs its scheme; a bare host such as "conectar-frontend.vercel.app"
    # can never match and has been dropped.
    origins = [
        "https://conectar-frontend.vercel.app",
        "https://boraconectar.com",
    ]

    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["Content-Type", "Accept", "authorization"],
    )

Beispiel #17
0
                       tags=["SpatioTemporal Asset Catalog"])

if not api_settings.disable_mosaic:
    app.include_router(mosaic.router,
                       prefix="/mosaicjson",
                       tags=["MosaicJSON"])

app.include_router(tms.router, tags=["TileMatrixSets"])
add_exception_handlers(app, DEFAULT_STATUS_CODES)

# Set all CORS enabled origins
if api_settings.cors_origins:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=api_settings.cors_origins,
        allow_credentials=True,
        allow_methods=["GET"],
        allow_headers=["*"],
    )

app.add_middleware(BrotliMiddleware, minimum_size=0, gzip_fallback=True)
app.add_middleware(CacheControlMiddleware,
                   cachecontrol=api_settings.cachecontrol)
app.add_middleware(TotalTimeMiddleware)
if api_settings.debug:
    app.add_middleware(LoggerMiddleware)


@app.get("/ping", description="Health Check", tags=["Health Check"])
def ping():
    """Health check."""
Beispiel #18
0
app = FastAPI()

origins = [
    'http://*****:*****@app.get('/')
def alive():
    return {'info': 'alive'}


class NewGame(BaseModel):
    host: str
Beispiel #19
0
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .routers import songs, composers, collections
from .config import allowed_origins


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_methods=['*']
)

app.include_router(songs.router)
app.include_router(composers.router)
app.include_router(collections.router)

@app.get('/')
def root():
    """Landing endpoint returning a short greeting string."""
    message = 'Hello there!'
    return message
Beispiel #20
0
# FastAPI
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi_sqlalchemy import DBSessionMiddleware  # middleware helper

# Database
from common.database import init_db, shutdown_db, fa_users_db, engine
from common.utils import vars, SECRET, utcnow

app = FastAPI()
# app.secret_key = SECRET
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(DBSessionMiddleware, db_url=vars.DB_FULL)
# 9131155e: attempted log-filtering


@app.on_event("startup")
async def startup():
    """Connect the fastapi-users database, then run ``init_db()``."""
    connect = fa_users_db.connect
    await connect()
    init_db()


@app.on_event("shutdown")
async def shutdown_session():
    """Disconnect the fastapi-users database, then call ``shutdown_db()``."""
    disconnect = fa_users_db.disconnect
    await disconnect()
    shutdown_db()
Beispiel #21
0
from starlette.middleware.cors import CORSMiddleware

from app.api.api_v1.api import api_router
from app.core.config import settings
from app.initial_data import init_in_memory_songs
from app.core.exceptions import apply_exception_handlers

# Bootstrapping section. Might move this in a separate script
init_in_memory_songs()

app = FastAPI(
    title=settings.PROJECT_NAME,
    openapi_url=f"{settings.API_V1_STR}/openapi.json",
)

apply_exception_handlers(app)

# Set all CORS enabled origins
if settings.BACKEND_CORS_ORIGINS:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[
            str(origin) for origin in settings.BACKEND_CORS_ORIGINS
        ],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

app.include_router(api_router, prefix=settings.API_V1_STR)
Beispiel #22
0
Datei: app.py Projekt: srbhr/jina
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
):
    """
    Get the app from FastAPI as the REST interface.

    Builds a FastAPI application that forwards HTTP requests to the Flow
    through a ``RequestStreamer``. Which routes exist depends on ``args``:
    debug routes (``/status``, ``/post``) unless ``args.no_debug_endpoints``;
    CRUD routes (``/index``, ``/search``, ``/delete``, ``/update``) unless
    ``args.no_crud_endpoints``; and any routes described in
    ``args.expose_endpoints``.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :return: fastapi app
    """
    with ImportExtensions(required=True):
        from fastapi import FastAPI
        from starlette.requests import Request
        from fastapi.responses import HTMLResponse
        from fastapi.middleware.cors import CORSMiddleware
        from jina.serve.runtimes.gateway.http.models import (
            JinaStatusModel,
            JinaRequestModel,
            JinaEndpointRequestModel,
            JinaResponseModel,
        )

    # Swagger UI path. FastAPI's built-in UI is disabled (docs_url=None) when
    # args.default_swagger_ui is false; a custom page is mounted here instead
    # near the end of this function.
    docs_url = '/docs'
    app = FastAPI(
        title=args.title or 'My Jina Service',
        description=args.description or
        'This is my awesome service. You can set `title` and `description` in your `Flow` or `Gateway` '
        'to customize this text.',
        version=__version__,
        docs_url=docs_url if args.default_swagger_ui else None,
    )

    # Fully permissive CORS (any origin, credentials allowed) is opt-in.
    if args.cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        logger.warning(
            'CORS is enabled. This service is now accessible from any website!'
        )

    from jina.serve.stream import RequestStreamer
    from jina.serve.runtimes.gateway.request_handling import (
        handle_request,
        handle_result,
    )

    streamer = RequestStreamer(
        args=args,
        request_handler=handle_request(graph=topology_graph,
                                       connection_pool=connection_pool),
        result_handler=handle_result,
    )
    # Alias ``Call`` to ``stream``, presumably for callers that expect the
    # gRPC-style ``Call`` name — confirm.
    streamer.Call = streamer.stream

    # Release pooled connections when the server shuts down.
    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    # Tag metadata is collected below and assigned to app.openapi_tags
    # after all routes are registered.
    openapi_tags = []
    if not args.no_debug_endpoints:
        openapi_tags.append({
            'name':
            'Debug',
            'description':
            'Debugging interface. In production, you should hide them by setting '
            '`--no-debug-endpoints` in `Flow`/`Gateway`.',
        })

        @app.get(
            path='/status',
            summary='Get the status of Jina service',
            response_model=JinaStatusModel,
            tags=['Debug'],
        )
        async def _status():
            """
            Get the status of this Jina service.

            This is equivalent to running `jina -vf` from command line.

            .. # noqa: DAR201
            """
            _info = get_full_version()
            return {
                'jina': _info[0],
                'envs': _info[1],
                'used_memory': used_memory_readable(),
            }

        @app.post(
            path='/post',
            summary='Post a data request to some endpoint',
            response_model=JinaResponseModel,
            tags=['Debug']
            # NOTE(review): a previous comment here said response_model should NOT be
            # set on this debug endpoint (so it does not restrict the response), yet
            # response_model=JinaResponseModel is set above — confirm which is intended.
        )
        async def post(body: JinaEndpointRequestModel):
            """
            Post a data request to some endpoint.

            This is equivalent to the following:

                from jina import Flow

                f = Flow().add(...)

                with f:
                    f.post(endpoint, ...)

            .. # noqa: DAR201
            .. # noqa: DAR101
            """
            # The above comment is written in Markdown for better rendering in FastAPI
            from jina.enums import DataInputType

            bd = body.dict()  # type: Dict
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            # When the payload wraps documents as {'docs': [...]}, pass the
            # inner list to request_generator.
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data'][
                    'docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input))
            return result

    def expose_executor_endpoint(exec_endpoint, http_path=None, **kwargs):
        """Exposing an executor endpoint to http endpoint
        :param exec_endpoint: the executor endpoint
        :param http_path: the http endpoint
        :param kwargs: kwargs accepted by FastAPI
        """

        # set some default kwargs for richer semantics
        # group flow exposed endpoints into `customized` group
        kwargs['tags'] = kwargs.get('tags', ['Customized'])
        kwargs['response_model'] = kwargs.get(
            'response_model',
            JinaResponseModel,  # use standard response model by default
        )
        kwargs['methods'] = kwargs.get('methods', ['POST'])

        @app.api_route(path=http_path or exec_endpoint,
                       name=http_path or exec_endpoint,
                       **kwargs)
        async def foo(body: JinaRequestModel):
            from jina.enums import DataInputType

            # Fall back to an empty payload when no body was parsed.
            bd = body.dict() if body else {'data': None}
            bd['exec_endpoint'] = exec_endpoint
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data'][
                    'docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input))
            return result

    if not args.no_crud_endpoints:
        # NOTE(review): "you can should" in the user-facing description below
        # looks like a typo ("you should").
        openapi_tags.append({
            'name':
            'CRUD',
            'description':
            'CRUD interface. If your service does not implement those interfaces, you can should '
            'hide them by setting `--no-crud-endpoints` in `Flow`/`Gateway`.',
        })
        crud = {
            '/index': {
                'methods': ['POST']
            },
            '/search': {
                'methods': ['POST']
            },
            '/delete': {
                'methods': ['DELETE']
            },
            '/update': {
                'methods': ['PUT']
            },
        }
        # Each CRUD route forwards to the executor endpoint of the same name.
        for k, v in crud.items():
            v['tags'] = ['CRUD']
            v['description'] = f'Post data requests to the Flow. Executors with `@requests(on="{k}")` will respond.'
            expose_executor_endpoint(exec_endpoint=k, **v)

    if openapi_tags:
        app.openapi_tags = openapi_tags

    # User-defined routes: a JSON object mapping executor endpoint -> route
    # kwargs, each forwarded to expose_executor_endpoint.
    if args.expose_endpoints:
        endpoints = json.loads(args.expose_endpoints)  # type: Dict[str, Dict]
        for k, v in endpoints.items():
            expose_executor_endpoint(exec_endpoint=k, **v)

    if not args.default_swagger_ui:

        async def _render_custom_swagger_html(req: Request) -> HTMLResponse:
            # NOTE(review): urllib.request.urlopen is a blocking call inside an
            # async handler, so it stalls the event loop while the page downloads.
            # The incoming ``req`` parameter is also shadowed by the urllib Request.
            import urllib.request

            swagger_url = 'https://api.jina.ai/swagger'
            req = urllib.request.Request(swagger_url,
                                         headers={'User-Agent': 'Mozilla/5.0'})
            with urllib.request.urlopen(req) as f:
                return HTMLResponse(f.read().decode())

        app.add_route(docs_url,
                      _render_custom_swagger_html,
                      include_in_schema=False)

    # Defined after the endpoints that call it. This works because those
    # nested functions look the name up in the enclosing scope at call time.
    async def _get_singleton_result(request_iterator) -> Dict:
        """
        Streams results from AsyncPrefetchCall as a dict

        :param request_iterator: request iterator, with length of 1
        :return: the first result from the request iterator (None if the
            stream yields nothing)
        """
        async for k in streamer.stream(request_iterator=request_iterator):
            request_dict = k.to_dict()
            return request_dict

    return app
Beispiel #23
0
def get_fastapi_app(args: 'argparse.Namespace', logger: 'JinaLogger'):
    """
    Get the app from FastAPI as the REST interface.

    Routes: ``/status``, the deprecated ``/api/{mode}``, CRUD routes
    (``/index``, ``/search``, ``/update``, ``/delete``) that stream results
    back as JSON, and a ``/stream`` websocket. All of them forward requests
    through a ZeroMQ-backed ``AsyncZmqlet``.

    :param args: passed arguments.
    :param logger: Jina logger.
    :return: fastapi app
    """
    with ImportExtensions(required=True):
        from fastapi import FastAPI, WebSocket, Body
        from fastapi.responses import JSONResponse
        from fastapi.middleware.cors import CORSMiddleware
        from starlette.endpoints import WebSocketEndpoint
        from starlette import status
        from starlette.types import Receive, Scope, Send
        from starlette.responses import StreamingResponse
        from .models import (
            JinaStatusModel,
            JinaIndexRequestModel,
            JinaDeleteRequestModel,
            JinaUpdateRequestModel,
            JinaSearchRequestModel,
        )

    app = FastAPI(
        title='Jina',
        description='REST interface for Jina',
        version=__version__,
    )
    # CORS is always fully open here (any origin, credentials allowed).
    app.add_middleware(
        CORSMiddleware,
        allow_origins=['*'],
        allow_credentials=True,
        allow_methods=['*'],
        allow_headers=['*'],
    )
    # NOTE(review): the zmqlet (and the startup hook) log via the module-level
    # ``default_logger``, not the ``logger`` argument used by the websocket
    # endpoint — confirm this is intended.
    zmqlet = AsyncZmqlet(args, default_logger)
    servicer = AsyncPrefetchCall(args, zmqlet)

    def error(reason, status_code):
        """
        Get the error code.

        :param reason: content of error
        :param status_code: status code
        :return: error in JSON response
        """
        return JSONResponse(content={'reason': reason},
                            status_code=status_code)

    @app.on_event('shutdown')
    def _shutdown():
        zmqlet.close()

    @app.on_event('startup')
    async def startup():
        """Log the host information when start the server."""
        default_logger.info(f'''
    Jina REST interface
    💬 Swagger UI:\thttp://localhost:{args.port_expose}/docs
    📚 Redoc     :\thttp://localhost:{args.port_expose}/redoc
        ''')
        from jina import __ready_msg__

        default_logger.success(__ready_msg__)

    @app.get(
        path='/status',
        summary='Get the status of Jina',
        response_model=JinaStatusModel,
        tags=['jina'],
    )
    async def _status():
        _info = get_full_version()
        return {
            'jina': _info[0],
            'envs': _info[1],
            'used_memory': used_memory_readable(),
        }

    @app.post(path='/api/{mode}', deprecated=True)
    async def api(mode: str, body: Any = Body(...)):
        """
        Request mode service and return results in JSON, a deprecated interface.

        :param mode: INDEX, SEARCH, DELETE, UPDATE, CONTROL, TRAIN.
        :param body: Request body.
        :return: Results in JSONresponse.
        """
        warnings.warn('this interface will be retired soon',
                      DeprecationWarning)
        if mode.upper() not in RequestType.__members__:
            return error(reason=f'unsupported mode {mode}', status_code=405)

        if 'data' not in body:
            return error('"data" field is empty', 406)

        body['mode'] = RequestType.from_string(mode)
        from .....clients import BaseClient

        BaseClient.add_default_kwargs(body)
        req_iter = request_generator(**body)
        results = await get_result_in_json(req_iter=req_iter)
        # Only the first result is returned to the caller.
        return JSONResponse(content=results[0], status_code=200)

    async def get_result_in_json(req_iter):
        """
        Convert message to JSON data.

        :param req_iter: Request iterator
        :return: Results in JSON format
        """
        return [
            MessageToDict(k)
            async for k in servicer.Call(request_iterator=req_iter,
                                         context=None)
        ]

    @app.post(path='/index',
              summary='Index documents into Jina',
              tags=['CRUD'])
    async def index_api(body: JinaIndexRequestModel):
        """
        Index API to index documents.

        :param body: index request.
        :return: Response of the results.
        """
        from .....clients import BaseClient

        bd = body.dict()
        bd['mode'] = RequestType.INDEX
        BaseClient.add_default_kwargs(bd)
        return StreamingResponse(result_in_stream(request_generator(**bd)),
                                 media_type='application/json')

    @app.post(path='/search',
              summary='Search documents from Jina',
              tags=['CRUD'])
    async def search_api(body: JinaSearchRequestModel):
        """
        Search API to search documents.

        :param body: search request.
        :return: Response of the results.
        """
        from .....clients import BaseClient

        bd = body.dict()
        bd['mode'] = RequestType.SEARCH
        BaseClient.add_default_kwargs(bd)
        return StreamingResponse(result_in_stream(request_generator(**bd)),
                                 media_type='application/json')

    @app.put(path='/update', summary='Update documents in Jina', tags=['CRUD'])
    async def update_api(body: JinaUpdateRequestModel):
        """
        Update API to update documents.

        :param body: update request.
        :return: Response of the results.
        """
        from .....clients import BaseClient

        bd = body.dict()
        bd['mode'] = RequestType.UPDATE
        BaseClient.add_default_kwargs(bd)
        return StreamingResponse(result_in_stream(request_generator(**bd)),
                                 media_type='application/json')

    @app.delete(path='/delete',
                summary='Delete documents in Jina',
                tags=['CRUD'])
    async def delete_api(body: JinaDeleteRequestModel):
        """
        Delete API to delete documents.

        :param body: delete request.
        :return: Response of the results.
        """
        from .....clients import BaseClient

        bd = body.dict()
        bd['mode'] = RequestType.DELETE
        BaseClient.add_default_kwargs(bd)
        return StreamingResponse(result_in_stream(request_generator(**bd)),
                                 media_type='application/json')

    async def result_in_stream(req_iter):
        """
        Streams results from AsyncPrefetchCall as json

        :param req_iter: request iterator
        :yield: result
        """
        async for k in servicer.Call(request_iterator=req_iter, context=None):
            yield MessageToJson(k)

    @app.websocket_route(path='/stream')
    class StreamingEndpoint(WebSocketEndpoint):
        """
        :meth:`handle_receive()`
            Await a message on :meth:`websocket.receive()`
            Send the message to zmqlet via :meth:`zmqlet.send_message()` and await
        :meth:`handle_send()`
            Await a message on :meth:`zmqlet.recv_message()`
            Send the message back to client via :meth:`websocket.send()` and await
        :meth:`dispatch()`
            Awaits on concurrent tasks :meth:`handle_receive()` & :meth:`handle_send()`
            This makes sure gateway is nonblocking
        Await exit strategy:
            :meth:`handle_receive()` keeps track of num_requests received
            :meth:`handle_send()` keeps track of num_responses sent
            Client sends a final message: `bytes(True)` to indicate request iterator is empty
            Server exits out of await when `(num_requests == num_responses != 0 and is_req_empty)`
        """

        # No fixed encoding: ``decode`` records per client whether it sends
        # text or bytes, and replies are sent in the same form.
        encoding = None

        def __init__(self, scope: 'Scope', receive: 'Receive',
                     send: 'Send') -> None:
            super().__init__(scope, receive, send)
            self.args = args
            self.name = args.name or self.__class__.__name__
            self._id = random_identity()
            self.client_encoding = None

        async def dispatch(self) -> None:
            """Awaits on concurrent tasks :meth:`handle_receive()` & :meth:`handle_send()`"""
            websocket = WebSocket(self.scope,
                                  receive=self.receive,
                                  send=self.send)
            await self.on_connect(websocket)
            close_code = status.WS_1000_NORMAL_CLOSURE

            # NOTE(review): only handle_receive is gathered; no handle_send method
            # exists in this class despite the docstrings above. handle_receive
            # itself sends each response after receiving it.
            await asyncio.gather(
                self.handle_receive(websocket=websocket,
                                    close_code=close_code), )

        async def on_connect(self, websocket: WebSocket) -> None:
            """
            Await the websocket to accept and log the information.

            :param websocket: connected websocket
            """
            # TODO(Deepankar): To enable multiple concurrent clients,
            # Register each client - https://fastapi.tiangolo.com/advanced/websockets/#handling-disconnections-and-multiple-clients
            # And move class variables to instance variable
            await websocket.accept()
            self.client_info = f'{websocket.client.host}:{websocket.client.port}'
            logger.success(
                f'Client {self.client_info} connected to stream requests via websockets'
            )

        async def handle_receive(self, websocket: WebSocket,
                                 close_code: int) -> None:
            """
            Await a message on :meth:`websocket.receive()`
            Send the message to zmqlet via :meth:`zmqlet.send_message()` and await

            :param websocket: WebSocket connection between client and server.
            :param close_code: close code
            """
            def handle_route(msg: 'Message') -> 'Request':
                """
                Add route information to `message`.

                :param msg: receive message
                :return: message response with route information
                """
                msg.add_route(self.name, self._id)
                return msg.response

            try:
                while True:
                    message = await websocket.receive()
                    if message['type'] == 'websocket.receive':
                        data = await self.decode(websocket, message)
                        # bytes(True) == b'\x00'. Per the class docstring, the
                        # client sends it to signal its request iterator is empty;
                        # it is skipped rather than forwarded.
                        if data == bytes(True):
                            await asyncio.sleep(0.1)
                            continue
                        # One request is forwarded and one response awaited per
                        # message, so requests are handled strictly in sequence.
                        await zmqlet.send_message(
                            Message(None, Request(data), 'gateway',
                                    **vars(self.args)))
                        response = await zmqlet.recv_message(
                            callback=handle_route)
                        # Reply in the same encoding the client used.
                        if self.client_encoding == 'bytes':
                            await websocket.send_bytes(
                                response.SerializeToString())
                        else:
                            await websocket.send_json(response.json())
                    elif message['type'] == 'websocket.disconnect':
                        close_code = int(
                            message.get('code', status.WS_1000_NORMAL_CLOSURE))
                        break
            except Exception as exc:
                close_code = status.WS_1011_INTERNAL_ERROR
                logger.error(f'Got an exception in handle_receive: {exc!r}')
                raise
            finally:
                await self.on_disconnect(websocket, close_code)

        async def decode(self, websocket: WebSocket, message: Message) -> Any:
            """
            Decode the text or bytes format `message`

            :param websocket: WebSocket connection.
            :param message: raw ASGI websocket message (the mapping returned by
                `websocket.receive()`), not a Jina `Message`.
            :return: decoded message.
            """
            # Remember the client's encoding so replies can match it.
            if 'text' in message or 'json' in message:
                self.client_encoding = 'text'

            if 'bytes' in message:
                self.client_encoding = 'bytes'

            return await super().decode(websocket, message)

        async def on_disconnect(self, websocket: WebSocket,
                                close_code: int) -> None:
            """
            Log the information when client is disconnected.

            :param websocket: disconnected websocket
            :param close_code: close code (currently not logged or used)
            """
            logger.info(f'Client {self.client_info} got disconnected!')

    return app
Beispiel #24
0
def test_fastapi_features(serve_instance):
    """Check that a feature-rich FastAPI app keeps working behind Serve.

    The app built here exercises an OpenAPI URL override, a startup hook, an
    HTTP middleware, pydantic body/response models, query/cookie/header
    parameters, plain and generator dependencies, a custom exception handler,
    background tasks, CORS and an ``APIRouter``. It is deployed with
    ``serve.ingress`` and then queried over HTTP with ``requests``.
    """
    import os

    app = FastAPI(openapi_url="/my_api.json")

    @app.on_event("startup")
    def inject_state():
        app.state.state_one = "app.state"

    @app.middleware("http")
    async def add_process_time_header(request: Request, call_next):
        start_time = time.time()
        response = await call_next(request)
        process_time = time.time() - start_time
        response.headers["X-Process-Time"] = str(process_time)
        return response

    class Nested(BaseModel):
        val: int

    class BodyType(BaseModel):
        name: str
        price: float = Field(None, gt=1.0, description="High price!")
        nests: Nested

    class RespModel(BaseModel):
        ok: bool
        vals: List[Any]
        file_path: str

    async def yield_db():
        yield "db"

    async def common_parameters(q: Optional[str] = None):
        return {"q": q}

    @app.exception_handler(ValueError)
    async def custom_handler(_: Request, exc: ValueError):
        return JSONResponse(status_code=500,
                            content={
                                "custom_error": "true",
                                "message": str(exc)
                            })

    def run_background(background_tasks: BackgroundTasks):
        # tempfile.mktemp() is deprecated and racy: the returned name can be
        # claimed by someone else before it is used. Create the file
        # atomically instead and hand only its path to the background task.
        with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp:
            path = tmp.name

        def write_to_file(p):
            with open(p, "w") as f:
                f.write("hello")

        background_tasks.add_task(write_to_file, path)
        return path

    # CORSMiddleware's allow_origins is a sequence of origins; pass a list
    # rather than relying on a bare "*" string being treated the same way.
    app.add_middleware(CORSMiddleware, allow_origins=["*"])

    @app.get("/{path_arg}", response_model=RespModel, status_code=201)
    async def func(
            path_arg: str,
            query_arg: str,
            body_val: BodyType,
            backgrounds_tasks: BackgroundTasks,
            do_error: bool = False,
            query_arg_valid: Optional[str] = Query(None, min_length=3),
            cookie_arg: Optional[str] = Cookie(None),
            user_agent: Optional[str] = Header(None),
            commons: dict = Depends(common_parameters),
            db=Depends(yield_db),
    ):
        if do_error:
            raise ValueError("bad input")

        path = run_background(backgrounds_tasks)

        return RespModel(
            ok=True,
            vals=[
                path_arg,
                query_arg,
                body_val.price,
                body_val.nests.val,
                do_error,
                query_arg_valid,
                cookie_arg,
                user_agent.split("/")[0],  # returns python-requests
                commons,
                db,
                app.state.state_one,
            ],
            file_path=path,
        )

    router = APIRouter(prefix="/prefix")

    @router.get("/subpath")
    def router_path():
        return "ok"

    app.include_router(router)

    @serve.deployment(name="fastapi")
    @serve.ingress(app)
    class Worker:
        pass

    Worker.deploy()

    url = "http://localhost:8000/fastapi"
    resp = requests.get(f"{url}/")
    assert resp.status_code == 404
    assert "x-process-time" in resp.headers

    resp = requests.get(f"{url}/my_api.json")
    assert resp.status_code == 200
    assert resp.json()  # it returns a well-formed json.

    resp = requests.get(f"{url}/docs")
    assert resp.status_code == 200
    assert "<!DOCTYPE html>" in resp.text

    resp = requests.get(f"{url}/redoc")
    assert resp.status_code == 200
    assert "<!DOCTYPE html>" in resp.text

    resp = requests.get(f"{url}/path_arg")
    assert resp.status_code == 422  # Malformed input

    resp = requests.get(f"{url}/path_arg",
                        json={
                            "name": "serve",
                            "price": 12,
                            "nests": {
                                "val": 1
                            }
                        },
                        params={
                            "query_arg": "query_arg",
                            "query_arg_valid": "at-least-three-chars",
                            "q": "common_arg",
                        })
    assert resp.status_code == 201, resp.text
    body = resp.json()
    assert body["ok"]
    assert body["vals"] == [
        "path_arg",
        "query_arg",
        12.0,
        1,
        False,
        "at-least-three-chars",
        None,
        "python-requests",
        {
            "q": "common_arg"
        },
        "db",
        "app.state",
    ]
    # Read through a context manager so the handle is closed deterministically,
    # then remove the temp file created by run_background.
    with open(body["file_path"]) as f:
        assert f.read() == "hello"
    os.remove(body["file_path"])

    resp = requests.get(f"{url}/path_arg",
                        json={
                            "name": "serve",
                            "price": 12,
                            "nests": {
                                "val": 1
                            }
                        },
                        params={
                            "query_arg": "query_arg",
                            "query_arg_valid": "at-least-three-chars",
                            "q": "common_arg",
                            "do_error": "true"
                        })
    assert resp.status_code == 500
    assert resp.json()["custom_error"] == "true"

    resp = requests.get(f"{url}/prefix/subpath")
    assert resp.status_code == 200

    resp = requests.get(f"{url}/docs",
                        headers={
                            "Access-Control-Request-Method": "GET",
                            "Origin": "https://googlebot.com"
                        })
    assert resp.headers["access-control-allow-origin"] == "*", resp.headers
Beispiel #25
0
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

from app.api import predict, viz

app = FastAPI(
    title='DS AirBnB API',
    description='API containing property data and pickled models',
    version='0.1',
    docs_url='/',
)

app.include_router(predict.router)
app.include_router(viz.router)

# allow_origins=['*'] already admits every origin. Starlette's CORS check
# returns early for the wildcard and never consults allow_origin_regex, so the
# former 'https?://.*' regex was dead configuration and has been dropped.
# NOTE(review): wildcard origins combined with allow_credentials=True means any
# site can make credentialed requests (Starlette echoes the caller's Origin
# when cookies are present). Restrict allow_origins if that is not intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)

if __name__ == '__main__':
    uvicorn.run(app)
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware

app = FastAPI()

# Gzip-compress responses whose body is at least 1000 bytes long.
app.add_middleware(GZipMiddleware, minimum_size=1000)


@app.get("/")
async def main():
    # Root route: always answers with the same fixed string payload.
    payload = "somebigcontent"
    return payload
Beispiel #27
0
    config={
        "sampler": {
            "type": "const",
            "param": 1
        },
        "logging": True,
        "local_agent": {
            "reporting_host": "localhost"
        },
    },
    scope_manager=ContextVarsScopeManager(),
    service_name=f"{service_name}_opentracing",
)
# Build the opentracing tracer from `opentracing_config` (constructed just above).
tracer_opentracing = opentracing_config.initialize_tracer()
# NOTE(review): presumably instruments supported client libraries for
# opentracing — confirm against install_all_patches' documentation.
install_all_patches()
# Register the Starlette tracing middleware once per tracer. `shim` is defined
# earlier in the file (not visible here).
# NOTE(review): the same middleware class is added twice, once with each
# tracer; confirm that tracing every request through both is intended.
app.add_middleware(StarletteTracingMiddleWare,
                   tracer=shim)  # Use opentelemetry
app.add_middleware(StarletteTracingMiddleWare,
                   tracer=tracer_opentracing)  # Use opentracing

# =========


class ServiceIn(BaseModel):
    # Pydantic model with five required fields describing a service endpoint.
    # Presumably the inbound payload for a service record (the "In" suffix),
    # TODO confirm against the routes that use it. Field order is kept as
    # declared because pydantic uses it for schema and serialization order.
    name: str
    stage: str  # presumably a deployment-stage label — TODO confirm values
    host: str
    port: int
    active: bool


class Service(BaseModel):
Beispiel #28
0
from typing import List, Dict

import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from pydantic import BaseModel


app = FastAPI()
# Allow browser calls from http://localhost:3000 (presumably a local frontend)
# with any HTTP method.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_origins=["http://localhost:3000"],
)

fake_db = {
    "teams": {
        "team_a": {1: 120, 2: 200, 3: 225, 4: 300},
        "team_b": {1: 100, 2: 150, 3: 200, 4: 240},
        "team_c": {1: 75, 2: 90, 3: 113, 4: 180},
    },
    "competitions": {
        "comp_a": [
            {"server_name": "server_a", "team_name": "team_a"},
            {"server_name": "server_a", "team_name": "team_b"},
        ],
        "comp_b": [
            {"server_name": "server_b", "team_name": "team_a"},
            {"server_name": "server_c", "team_name": "team_c"},
        ],
        "comp_c": [
            {"server_name": "server_b", "team_name": "team_b"},
            {"server_name": "server_a", "team_name": "team_c"},
from models.user_models import UserIn, UserOut
from models.transaction_models import TransactionIn, TransactionOut
import datetime
from fastapi import FastAPI
from fastapi import HTTPException
api = FastAPI()
from fastapi.middleware.cors import CORSMiddleware
origins = [
    "http://localhost.tiangolo.com", "https://localhost.tiangolo.com",
    "http://localhost", "http://*****:*****@api.post("/user/auth/")
async def auth_user(user_in: UserIn):
    # Check the submitted credentials against the record returned by get_user.
    # Raises HTTP 404 when the username is unknown; otherwise responds with
    # {"Autenticado": True} or {"Autenticado": False} depending on whether the
    # password matches.
    import secrets

    user_in_db = get_user(user_in.username)
    if user_in_db is None:
        raise HTTPException(status_code=404, detail="El usuario no existe")
    # Constant-time comparison so response timing does not reveal how much of
    # the password matched.
    # NOTE(review): assumes both passwords are str. They also appear to be
    # stored and compared in plain text; consider storing a salted hash.
    if not secrets.compare_digest(user_in_db.password.encode("utf-8"),
                                  user_in.password.encode("utf-8")):
        return {"Autenticado": False}
    return {"Autenticado": True}

# Module-level model setup, executed once at import time.
# `ner` must already exist at this point (it is defined earlier in the file,
# outside this view). A SentimentPredictor is then created and bound to `nlp`.
# NOTE(review): load_generator presumably loads model weights — confirm.
ner.load_generator()
nlp = SentimentPredictor()
nlp.load_generator()


class Hansard_Predictor(BaseModel):
    # Request body for the inference endpoints: the client supplies
    # `hansard_text`; ner_inference stores its result in `output`.
    hansard_text: str
    # NOTE(review): annotated as `dict` yet defaults to None;
    # `Optional[dict] = None` would state the intent precisely.
    output: dict = None


app = FastAPI()

# Fully open CORS policy (any origin, method and header) with credentials
# disabled.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=False,
    allow_origins=['*'],
)


@app.get("/")
def read_root():
    # Root route: a fixed banner naming the service.
    banner = {"Singapore Parliament Hansard NLP": "API"}
    return banner


@app.post("/ner/")
async def ner_inference(hansard: Hansard_Predictor):
    # Run the `ner` model's generate_entities on the submitted text, store the
    # result on the request model and return it under the "output" key.
    # NOTE(review): generate_entities is called synchronously inside an async
    # handler; if it is slow it will block the event loop — confirm.
    entities = ner.generate_entities(hansard.hansard_text)
    hansard.output = entities
    return {"output": hansard.output}