def test_delete_collection(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """Check that a deleted collection can no longer be fetched.

    The collection is built from ``test_collection.json``, created through the
    transactions client, deleted again, and then requested from the core
    client, which is expected to raise ``NotFoundError``.
    """
    collection = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(collection, request=MockStarletteRequest)

    removed = postgres_transactions.delete_collection(
        collection.id, request=MockStarletteRequest
    )

    # The lookup uses the id carried by the object returned from the delete call.
    with pytest.raises(NotFoundError):
        postgres_core.get_collection(removed.id, request=MockStarletteRequest)
def test_delete_item(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """Check that a deleted item can no longer be fetched.

    A collection and an item are created from the JSON test data, the item is
    deleted by id, and a subsequent ``get_item`` is expected to raise
    ``NotFoundError``.
    """
    parent = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(parent, request=MockStarletteRequest)

    stored_item = Item.parse_obj(load_test_data("test_item.json"))
    postgres_transactions.create_item(stored_item, request=MockStarletteRequest)

    postgres_transactions.delete_item(stored_item.id, request=MockStarletteRequest)

    with pytest.raises(NotFoundError):
        postgres_core.get_item(stored_item.id, request=MockStarletteRequest)
def postgres_core(db_session):
    """Return a ``CoreCrudClient`` bound to ``db_session``.

    The client is wired to the ``Item``, ``Collection`` and
    ``PaginationToken`` tables exposed by the ``database`` module.
    """
    table_bindings = {
        "item_table": database.Item,
        "collection_table": database.Collection,
        "token_table": database.PaginationToken,
    }
    return CoreCrudClient(session=db_session, **table_bindings)
def cleanup(postgres_core: CoreCrudClient, postgres_transactions: TransactionsClient):
    """Yield once, then remove test collections and their items.

    After the ``yield``, every collection whose id has ``test`` as its first
    ``-``-separated segment is emptied (items fetched with ``limit=100``) and
    then deleted.
    """
    yield

    for collection in postgres_core.all_collections(request=MockStarletteRequest):
        leading_segment = collection.id.split("-")[0]
        if leading_segment != "test":
            continue

        # Items are removed before the collection that holds them.
        page = postgres_core.item_collection(
            collection.id, limit=100, request=MockStarletteRequest
        )
        for feature in page.features:
            postgres_transactions.delete_item(feature.id, request=MockStarletteRequest)

        postgres_transactions.delete_collection(
            collection.id, request=MockStarletteRequest
        )
def test_get_collection(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """Check that a created collection reads back unchanged.

    The comparison ignores the ``links`` field; the id is additionally
    compared on its own.
    """
    expected = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(expected, request=MockStarletteRequest)

    fetched = postgres_core.get_collection(expected.id, request=MockStarletteRequest)

    expected_fields = expected.dict(exclude={"links"})
    fetched_fields = fetched.dict(exclude={"links"})
    assert expected_fields == fetched_fields
    assert fetched.id == expected.id
def api_client(db_session):
    """Return a ``StacApi`` whose clients all share ``db_session``.

    The API is built with default ``ApiSettings``, a ``CoreCrudClient`` and
    the transaction, context, sort, fields and query extensions.
    """
    api_settings = ApiSettings()
    core_client = CoreCrudClient(session=db_session)

    transaction_extension = TransactionExtension(
        client=TransactionsClient(session=db_session)
    )
    enabled_extensions = [
        transaction_extension,
        ContextExtension(),
        SortExtension(),
        FieldsExtension(),
        QueryExtension(),
    ]

    return StacApi(
        settings=api_settings,
        client=core_client,
        extensions=enabled_extensions,
    )
def test_update_collection(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """Check that an updated collection persists a newly appended keyword."""
    collection = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(collection, request=MockStarletteRequest)

    # Mutate the in-memory model, then push the change through the update call.
    collection.keywords.append("new keyword")
    postgres_transactions.update_collection(collection, request=MockStarletteRequest)

    reloaded = postgres_core.get_collection(
        collection.id, request=MockStarletteRequest
    )
    assert "new keyword" in reloaded.keywords
def test_create_item(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """Check that a created item reads back with the same content.

    ``links`` and the ``created``/``updated`` properties are left out of the
    comparison on both sides.
    """

    def comparable(model):
        # A fresh exclude spec is built on every call, as in a literal argument.
        return model.dict(
            exclude={"links": ..., "properties": {"created", "updated"}}
        )

    parent = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(parent, request=MockStarletteRequest)

    new_item = Item.parse_obj(load_test_data("test_item.json"))
    postgres_transactions.create_item(new_item, request=MockStarletteRequest)

    stored = postgres_core.get_item(new_item.id, request=MockStarletteRequest)
    assert comparable(new_item) == comparable(stored)
def postgres_core(reader_connection, writer_connection):
    """Yield a ``CoreCrudClient`` with patched session properties.

    While the generator is suspended at the ``yield``,
    ``PostgresClient.writer_session`` and ``PostgresClient.reader_session``
    are replaced by ``PropertyMock`` objects returning ``writer_connection``
    and ``reader_connection`` respectively.
    """
    writer_patch = patch(
        "stac_api.clients.postgres.base.PostgresClient.writer_session",
        new_callable=PropertyMock,
    )
    reader_patch = patch(
        "stac_api.clients.postgres.base.PostgresClient.reader_session",
        new_callable=PropertyMock,
    )

    # The writer patch is entered first and therefore exited last.
    with writer_patch as mock_writer, reader_patch as mock_reader:
        mock_writer.return_value = writer_connection
        mock_reader.return_value = reader_connection
        yield CoreCrudClient()
def test_update_item(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """Check that a property set on an item survives ``update_item``."""
    parent = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(parent, request=MockStarletteRequest)

    original = Item.parse_obj(load_test_data("test_item.json"))
    postgres_transactions.create_item(original, request=MockStarletteRequest)

    # Set a new property on the model and send it through the update call.
    original.properties.foo = "bar"
    postgres_transactions.update_item(original, request=MockStarletteRequest)

    reloaded = postgres_core.get_item(original.id, request=MockStarletteRequest)
    assert reloaded.properties.foo == "bar"
def test_get_collection_items(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """Check that ``item_collection`` returns every item of a collection.

    Five items are created from one template, each under a fresh UUID id;
    the listing must contain five features, all pointing at the collection.
    """
    parent = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(parent, request=MockStarletteRequest)

    template = Item.parse_obj(load_test_data("test_item.json"))
    for _ in range(5):
        # The same model instance is reused; only its id changes per insert.
        template.id = str(uuid.uuid4())
        postgres_transactions.create_item(template, request=MockStarletteRequest)

    listing = postgres_core.item_collection(parent.id, request=MockStarletteRequest)
    assert len(listing.features) == 5

    for feature in listing.features:
        assert feature.collection == parent.id
SortExtension, TilesExtension, TransactionExtension, )
from stac_api.clients.postgres.core import CoreCrudClient
from stac_api.clients.postgres.session import Session
from stac_api.clients.postgres.transactions import ( BulkTransactionsClient, TransactionsClient, )
from stac_api.clients.tiles.ogc import TilesClient
from stac_api.config import PostgresSettings

# Settings instance; supplies the connection strings below and is passed to StacApi.
settings = PostgresSettings()
# One Session built from the reader and writer connection strings; every client below receives it.
session = Session(settings.reader_connection_string, settings.writer_connection_string)
# API object: CoreCrudClient as the main client plus the transaction, bulk-transaction,
# fields, query, sort and tiles extensions.
api = StacApi( settings=settings, extensions=[ TransactionExtension(client=TransactionsClient(session=session)), BulkTransactionExtension(client=BulkTransactionsClient( session=session)), FieldsExtension(), QueryExtension(), SortExtension(), TilesExtension(TilesClient(session=session)), ], client=CoreCrudClient(session=session), )
# Module-level handle on the application object exposed by the API
# (presumably what the ASGI server imports; confirm against deployment config).
app = api.app
def create_app(settings: ApiSettings) -> FastAPI:
    """Create a FastAPI app.

    The app is assembled in this order:

    * the core router, backed by a ``CoreCrudClient`` with a
      ``PaginationTokenClient``;
    * default exception handlers and a permissive CORS middleware;
    * a handler that turns ``RequestValidationError`` into a 422 response;
    * the transaction router, if the transaction extension is enabled;
    * the OGC tiles route and the ``/titiler`` router, if the tiles add-on
      is enabled;
    * startup/shutdown hooks that create and dispose the reader and writer
      engines;
    * an HTTP middleware that opens one reader and one writer session per
      request;
    * the ``/login``, ``/token`` and ``/_mgmt/ping`` endpoints.

    The core, transaction and tiles routes are registered with
    ``Depends(oauth2_scheme)``; ``/login``, ``/token`` and ``/_mgmt/ping``
    are registered without it.

    Args:
        settings: API settings. Read here for ``debug``, the enabled
            extensions/add-ons and the reader/writer connection strings.

    Returns:
        The configured ``FastAPI`` application.
    """
    paging_client = PaginationTokenClient()
    core_client = CoreCrudClient(pagination_client=paging_client)

    app = FastAPI()
    inject_settings(settings)

    app.debug = settings.debug
    app.include_router(
        create_core_router(core_client, settings),
        tags=["Core Endpoints"],
        dependencies=[Depends(oauth2_scheme)],
    )
    add_exception_handlers(app, DEFAULT_STATUS_CODES)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_headers=["*"],
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request: Request, exc: RequestValidationError):
        """Return a 422 with the validation errors and the request's parameters."""
        return JSONResponse(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            content=jsonable_encoder(
                {
                    "detail": exc.errors(),
                    "query_params": request.query_params,
                    "path_params": request.path_params,
                }
            ),
        )

    if settings.api_extension_is_enabled(ApiExtensions.transaction):
        transaction_client = TransactionsClient()
        app.include_router(
            create_transactions_router(transaction_client, settings),
            tags=["Transaction Extension"],
            dependencies=[Depends(oauth2_scheme)],
        )

    if settings.add_on_is_enabled(AddOns.tiles):
        tiles_client = TilesClient()
        app.add_api_route(
            name="Get OGC Tiles Resource",
            path="/collections/{collectionId}/items/{itemId}/tiles",
            response_model=TileSetResource,
            response_model_exclude_none=True,
            response_model_exclude_unset=True,
            methods=["GET"],
            endpoint=create_endpoint_with_depends(tiles_client.get_item_tiles, ItemUri),
            tags=["OGC Tiles"],
            dependencies=[Depends(oauth2_scheme)],
        )
        app.include_router(
            create_tiles_router(),
            prefix="/titiler",
            tags=["Titiler"],
            dependencies=[Depends(oauth2_scheme)],
        )

    config_openapi(app, settings)

    @app.on_event("startup")
    async def on_startup():
        """Create database engines and sessions on startup"""
        app.state.ENGINE_READER = create_engine(
            settings.reader_connection_string, echo=settings.debug
        )
        app.state.ENGINE_WRITER = create_engine(
            settings.writer_connection_string, echo=settings.debug
        )
        app.state.DB_READER = sessionmaker(
            autocommit=False, autoflush=False, bind=app.state.ENGINE_READER
        )
        app.state.DB_WRITER = sessionmaker(
            autocommit=False, autoflush=False, bind=app.state.ENGINE_WRITER
        )

    @app.on_event("shutdown")
    async def on_shutdown():
        """Dispose of database engines and sessions on app shutdown"""
        app.state.ENGINE_READER.dispose()
        app.state.ENGINE_WRITER.dispose()

    @app.middleware("http")
    async def create_db_connection(request: Request, call_next):
        """Create a new database connection for each request"""
        # Requests whose URL contains "titiler" are passed through without
        # opening any session.
        if "titiler" in str(request.url):
            return await call_next(request)

        reader = request.app.state.DB_READER()
        writer = request.app.state.DB_WRITER()
        # READER / WRITER are defined outside this function; presumably
        # context variables read by the clients (TODO confirm).
        READER.set(reader)
        WRITER.set(writer)
        try:
            return await call_next(request)
        finally:
            # Close both sessions even when the downstream handler raises;
            # otherwise each failing request would leak two sessions.
            try:
                reader.close()
            finally:
                writer.close()

    @app.post("/login")
    async def login(body: Login):
        """Exchange a username/password for tokens; any failure becomes a 400."""
        try:
            return await get_tokens(body.username, body.password)
        except Exception as exception:
            # Chain the cause so the original traceback stays visible in logs.
            raise HTTPException(status_code=400, detail=f"{exception}") from exception

    @app.post("/token")
    async def get_token(form_data: OAuth2PasswordRequestForm = Depends()):
        """Return a bearer access token for form credentials; failures become a 400."""
        try:
            tokens = await get_tokens(form_data.username, form_data.password)
            return {"access_token": tokens["access_token"], "token_type": "bearer"}
        except Exception as exception:
            # A missing "access_token" key is reported as a 400 like any other failure.
            raise HTTPException(status_code=400, detail=f"{exception}") from exception

    mgmt_router = APIRouter()

    @mgmt_router.get("/_mgmt/ping")
    async def ping():
        """Liveliness/readiness probe"""
        return {"message": "PONG"}

    app.include_router(mgmt_router, tags=["Liveliness/Readiness"])

    return app