Example #1
0
def test_create_collection_already_exists(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """Creating the same collection twice must raise ``ConflictError``."""
    collection = Collection.parse_obj(load_test_data("test_collection.json"))
    request = MockStarletteRequest

    # The first insert seeds the database.
    postgres_transactions.create_collection(collection, request=request)

    # Re-inserting the identical collection is rejected.
    with pytest.raises(ConflictError):
        postgres_transactions.create_collection(collection, request=request)
Example #2
0
def test_get_collection(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """A created collection reads back unchanged, ignoring its links."""
    expected = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(
        expected, request=MockStarletteRequest
    )

    fetched = postgres_core.get_collection(
        expected.id, request=MockStarletteRequest
    )

    # Links are generated per request, so they are excluded from comparison.
    expected_body = expected.dict(exclude={"links"})
    fetched_body = fetched.dict(exclude={"links"})
    assert expected_body == fetched_body
    assert fetched.id == expected.id
Example #3
0
def test_update_collection(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """An added keyword is visible after ``update_collection``."""
    request = MockStarletteRequest
    collection = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(collection, request=request)

    # Mutate the local model, then persist the change.
    collection.keywords.append("new keyword")
    postgres_transactions.update_collection(collection, request=request)

    stored = postgres_core.get_collection(collection.id, request=request)
    assert "new keyword" in stored.keywords
Example #4
0
def test_create_item(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """A created item reads back unchanged, apart from links and timestamps."""

    def comparable(model):
        # Links and created/updated timestamps are not stable, so drop them.
        return model.dict(
            exclude={"links": ..., "properties": {"created", "updated"}}
        )

    collection = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(
        collection, request=MockStarletteRequest
    )

    item = Item.parse_obj(load_test_data("test_item.json"))
    postgres_transactions.create_item(item, request=MockStarletteRequest)

    stored = postgres_core.get_item(item.id, request=MockStarletteRequest)
    assert comparable(item) == comparable(stored)
Example #5
0
def test_delete_collection(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """A deleted collection can no longer be fetched."""
    request = MockStarletteRequest
    collection = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(collection, request=request)

    # delete_collection returns the removed collection; use its id below.
    removed = postgres_transactions.delete_collection(
        collection.id, request=request
    )

    with pytest.raises(NotFoundError):
        postgres_core.get_collection(removed.id, request=request)
Example #6
0
def cleanup(postgres_core: CoreCrudClient,
            postgres_transactions: TransactionsClient):
    """Let the test run, then delete test collections and their items.

    A collection counts as a test collection when the first ``-``-separated
    segment of its id is ``test``. Its items are deleted first, then the
    collection itself.
    """
    yield

    request = MockStarletteRequest
    for collection in postgres_core.all_collections(request=request):
        prefix = collection.id.split("-")[0]
        if prefix != "test":
            continue

        # NOTE(review): only the first page (limit=100) of items is removed;
        # collections with more items may not be fully emptied — confirm.
        page = postgres_core.item_collection(collection.id,
                                             limit=100,
                                             request=request)
        for feature in page.features:
            postgres_transactions.delete_item(feature.id, request=request)

        postgres_transactions.delete_collection(collection.id,
                                                request=request)
Example #7
0
def test_get_collection_items(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """Five inserted items are all returned by ``item_collection``."""
    request = MockStarletteRequest
    collection = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(collection, request=request)

    template = Item.parse_obj(load_test_data("test_item.json"))

    # Reuse one parsed item, giving it a fresh id before every insert.
    for _ in range(5):
        template.id = str(uuid.uuid4())
        postgres_transactions.create_item(template, request=request)

    page = postgres_core.item_collection(collection.id, request=request)
    assert len(page.features) == 5

    for feature in page.features:
        assert feature.collection == collection.id
Example #8
0
def test_bulk_item_insert(
    postgres_transactions: TransactionsClient,
    postgres_bulk_transactions: BulkTransactionsClient,
    load_test_data: Callable,
):
    """Bulk-insert ten items, then delete each one by id.

    The per-id deletes implicitly check that the insert took effect.
    """
    collection = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(
        collection, request=MockStarletteRequest
    )

    template = Item.parse_obj(load_test_data("test_item.json"))

    def with_fresh_id():
        # Serialize the template and swap in a new unique id.
        payload = template.dict()
        payload["id"] = str(uuid.uuid4())
        return payload

    payloads = [with_fresh_id() for _ in range(10)]

    postgres_bulk_transactions.bulk_item_insert(Items(items=payloads))

    for payload in payloads:
        postgres_transactions.delete_item(payload["id"],
                                          request=MockStarletteRequest)
Example #9
0
def api_client(db_session):
    """Build a ``StacApi`` whose clients share *db_session*.

    The API gets the transaction, context, sort, fields and query
    extensions.
    """
    api_settings = ApiSettings()
    core_client = CoreCrudClient(session=db_session)
    enabled_extensions = [
        TransactionExtension(
            client=TransactionsClient(session=db_session)
        ),
        ContextExtension(),
        SortExtension(),
        FieldsExtension(),
        QueryExtension(),
    ]
    return StacApi(
        settings=api_settings,
        client=core_client,
        extensions=enabled_extensions,
    )
Example #10
0
def postgres_transactions(reader_connection, writer_connection):
    """Yield a ``TransactionsClient`` whose sessions are the given connections.

    ``PostgresClient.writer_session`` and ``PostgresClient.reader_session``
    are replaced with ``PropertyMock``s for the lifetime of the fixture.
    """
    with patch(
        "stac_api.clients.postgres.base.PostgresClient.writer_session",
        new_callable=PropertyMock,
    ) as writer_property, patch(
        "stac_api.clients.postgres.base.PostgresClient.reader_session",
        new_callable=PropertyMock,
    ) as reader_property:
        writer_property.return_value = writer_connection
        reader_property.return_value = reader_connection
        yield TransactionsClient()
Example #11
0
def test_delete_item(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """A deleted item can no longer be fetched."""
    request = MockStarletteRequest

    collection = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(collection, request=request)

    item = Item.parse_obj(load_test_data("test_item.json"))
    postgres_transactions.create_item(item, request=request)
    postgres_transactions.delete_item(item.id, request=request)

    with pytest.raises(NotFoundError):
        postgres_core.get_item(item.id, request=request)
Example #12
0
def test_update_item(
    postgres_core: CoreCrudClient,
    postgres_transactions: TransactionsClient,
    load_test_data: Callable,
):
    """A property changed through ``update_item`` is persisted."""
    request = MockStarletteRequest

    collection = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(collection, request=request)

    item = Item.parse_obj(load_test_data("test_item.json"))
    postgres_transactions.create_item(item, request=request)

    # Set an extra property locally, then push the update.
    item.properties.foo = "bar"
    postgres_transactions.update_item(item, request=request)

    stored = postgres_core.get_item(item.id, request=request)
    assert stored.properties.foo == "bar"
Example #13
0
    SortExtension,
    TilesExtension,
    TransactionExtension,
)
from stac_api.clients.postgres.core import CoreCrudClient
from stac_api.clients.postgres.session import Session
from stac_api.clients.postgres.transactions import (
    BulkTransactionsClient,
    TransactionsClient,
)
from stac_api.clients.tiles.ogc import TilesClient
from stac_api.config import PostgresSettings

# Postgres configuration and the shared reader/writer session used by every
# client below.
settings = PostgresSettings()
session = Session(
    settings.reader_connection_string,
    settings.writer_connection_string,
)

# STAC API with the core CRUD client plus transaction, bulk-transaction,
# fields, query, sort and tiles extensions.
api = StacApi(
    settings=settings,
    extensions=[
        TransactionExtension(
            client=TransactionsClient(session=session),
        ),
        BulkTransactionExtension(
            client=BulkTransactionsClient(session=session),
        ),
        FieldsExtension(),
        QueryExtension(),
        SortExtension(),
        TilesExtension(
            TilesClient(session=session),
        ),
    ],
    client=CoreCrudClient(session=session),
)

# The application object exposed at module level.
app = api.app
Example #14
0
def postgres_transactions(db_session):
    """Return a ``TransactionsClient`` bound to *db_session* and the ORM tables."""
    table_config = {
        "item_table": database.Item,
        "collection_table": database.Collection,
    }
    return TransactionsClient(session=db_session, **table_config)
Example #15
0
def create_app(settings: ApiSettings) -> FastAPI:
    """Create and configure the FastAPI app.

    Always registers the core STAC router. When enabled in *settings*, also
    registers the transaction router and the OGC tiles / titiler routes. It
    also installs exception handlers, CORS, database engine startup/shutdown
    hooks, a per-request DB session middleware, ``/login`` and ``/token``
    endpoints, and a ``/_mgmt/ping`` probe.

    Args:
        settings: API settings. They are also registered globally via
            ``inject_settings``.

    Returns:
        The configured ``FastAPI`` application.
    """
    # Local import keeps the block self-contained; used by the DB middleware.
    from contextlib import closing

    paging_client = PaginationTokenClient()
    core_client = CoreCrudClient(pagination_client=paging_client)

    app = FastAPI()
    inject_settings(settings)

    app.debug = settings.debug
    app.include_router(
        create_core_router(core_client, settings), tags=["Core Endpoints"],
        dependencies=[Depends(oauth2_scheme)]
    )
    add_exception_handlers(app, DEFAULT_STATUS_CODES)

    # NOTE(review): allow_methods is not set, so Starlette's CORSMiddleware
    # default (GET only) applies to cross-origin requests — confirm intended.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_headers=["*"],
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request: Request,
                                           exc: RequestValidationError):
        """Return 422 with the validation errors and the request's params."""
        return JSONResponse(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            content=jsonable_encoder(
                {
                    "detail": exc.errors(),
                    "query_params": request.query_params,
                    "path_params": request.path_params,
                }
            ),
        )

    if settings.api_extension_is_enabled(ApiExtensions.transaction):
        transaction_client = TransactionsClient()
        app.include_router(
            create_transactions_router(transaction_client, settings),
            tags=["Transaction Extension"], dependencies=[Depends(oauth2_scheme)]
        )

    if settings.add_on_is_enabled(AddOns.tiles):
        tiles_client = TilesClient()
        app.add_api_route(
            name="Get OGC Tiles Resource",
            path="/collections/{collectionId}/items/{itemId}/tiles",
            response_model=TileSetResource,
            response_model_exclude_none=True,
            response_model_exclude_unset=True,
            methods=["GET"],
            endpoint=create_endpoint_with_depends(tiles_client.get_item_tiles, ItemUri),
            tags=["OGC Tiles"],
            dependencies=[Depends(oauth2_scheme)]
        )
        app.include_router(create_tiles_router(), prefix="/titiler", tags=["Titiler"], dependencies=[Depends(oauth2_scheme)])

    config_openapi(app, settings)

    @app.on_event("startup")
    async def on_startup():
        """Create database engines and sessions on startup"""
        app.state.ENGINE_READER = create_engine(
            settings.reader_connection_string, echo=settings.debug
        )
        app.state.ENGINE_WRITER = create_engine(
            settings.writer_connection_string, echo=settings.debug
        )
        app.state.DB_READER = sessionmaker(
            autocommit=False, autoflush=False, bind=app.state.ENGINE_READER
        )
        app.state.DB_WRITER = sessionmaker(
            autocommit=False, autoflush=False, bind=app.state.ENGINE_WRITER
        )

    @app.on_event("shutdown")
    async def on_shutdown():
        """Dispose of database engines and sessions on app shutdown"""
        app.state.ENGINE_READER.dispose()
        app.state.ENGINE_WRITER.dispose()

    @app.middleware("http")
    async def create_db_connection(request: Request, call_next):
        """Create a new database connection for each request.

        Titiler requests skip session creation. Otherwise a reader and a
        writer session are opened, published through the ``READER`` /
        ``WRITER`` context variables, and closed when the request finishes,
        including when the downstream handler raises.
        """
        # NOTE(review): this substring test matches "titiler" anywhere in the
        # URL (host, ids, query string), not only the "/titiler" prefix.
        if "titiler" in str(request.url):
            return await call_next(request)
        # closing() guarantees .close() even on exceptions; the original
        # version leaked both sessions whenever call_next raised.
        with closing(request.app.state.DB_READER()) as reader, closing(
            request.app.state.DB_WRITER()
        ) as writer:
            READER.set(reader)
            WRITER.set(writer)
            return await call_next(request)

    @app.post("/login")
    async def login(body: Login):
        """Exchange username/password for the tokens from ``get_tokens``."""
        try:
            tokens = await get_tokens(body.username, body.password)

            return tokens
        except Exception as exception:
            # NOTE(review): any failure becomes a 400 whose detail is the raw
            # exception text — confirm that exposing it to clients is intended.
            raise HTTPException(status_code=400, detail=f"{exception}") from exception

    @app.post("/token")
    async def get_token(form_data: OAuth2PasswordRequestForm = Depends()):
        """OAuth2 password-flow endpoint returning a bearer access token."""
        try:
            username = form_data.username
            password = form_data.password
            tokens = await get_tokens(username, password)
            access_token = tokens["access_token"]

            return {"access_token": access_token, "token_type": "bearer"}
        except Exception as exception:
            raise HTTPException(status_code=400, detail=f"{exception}") from exception

    mgmt_router = APIRouter()

    @mgmt_router.get("/_mgmt/ping")
    async def ping():
        """Liveliness/readiness probe"""
        return {"message": "PONG"}

    app.include_router(mgmt_router, tags=["Liveliness/Readiness"])

    return app