Example #1
0
    def get_database_model(cls, schema: schemas.Item) -> dict:
        """Decompose pydantic model to data model.

        Each name in ``Settings.get().indexed_fields`` is read from
        ``schema.properties`` and stored under the part of the name after the
        last ``:`` (so a namespaced field keeps only its final component);
        ``datetime`` is parsed with ``DATETIME_RFC339``. Those indexed fields
        are excluded from the ``properties`` dict, which gets ``created`` (only
        if missing/empty) and ``updated`` set to the current UTC time.

        Args:
            schema: The incoming item model.

        Returns:
            Keyword arguments for the database model: ``collection_id``,
            ``geometry`` (SRID 4326), ``properties``, the indexed fields, and
            the remaining non-None top-level fields of ``schema`` minus the
            forbidden fields.
        """
        # Fetch the shared settings once via the class accessor; calling
        # ``Settings()`` would construct a brand-new settings instance.
        settings = Settings.get()
        indexed_field_names = settings.indexed_fields

        indexed_fields = {}
        for field in indexed_field_names:
            # Use getattr to accommodate extension namespaces
            field_value = getattr(schema.properties, field)
            if field == "datetime":
                field_value = datetime.strptime(field_value, DATETIME_RFC339)
            indexed_fields[field.split(":")[-1]] = field_value

        # Exclude indexed fields from the properties jsonb field
        properties = schema.properties.dict(exclude=set(indexed_field_names))
        now = datetime.utcnow().strftime(DATETIME_RFC339)
        # ``.get`` tolerates a properties model that does not emit ``created``.
        if not properties.get("created"):
            properties["created"] = now
        properties["updated"] = now

        return dict(
            collection_id=schema.collection,
            geometry=ga.shape.from_shape(shape(schema.geometry), 4326),
            properties=properties,
            **indexed_fields,
            **schema.dict(
                exclude_none=True,
                exclude=set(settings.forbidden_fields)
                | {"geometry", "properties", "collection"},
            ),
        )
Example #2
0
    def __attrs_post_init__(self):
        """Finish wiring the application once attrs has built the instance.

        In order: share the extensions with the client, copy the fields
        extension's ``default_includes`` into the settings when that extension
        is present, publish the settings via ``Settings.set``, register the
        core routes and then every extension, add the health check, install
        the exception handlers and replace ``app.openapi`` with
        ``self.customize_openapi``.

        Returns:
            None
        """
        # Give the client the same extension list the app was configured with.
        self.client.extensions = self.extensions

        fields_extension = self.get_extension(FieldsExtension)
        if fields_extension:
            # The fields extension dictates which fields are always returned.
            self.settings.default_includes = fields_extension.default_includes

        # Publish settings before the core routes and extensions are registered.
        Settings.set(self.settings)

        self.register_core()
        for extension in self.extensions:
            extension.register(self.app)

        self.add_health_check()
        add_exception_handlers(self.app, status_codes=self.exceptions)
        self.app.openapi = self.customize_openapi
Example #3
0
 def __init__(self, obj: Any):
     """Decompose orm model to pydantic model.

     Indexed columns are copied back into a copy of ``obj.properties``
     (``datetime`` is formatted with ``DATETIME_RFC339``); inferred item links
     are created and any links stored on ``obj`` are appended; a new instance
     of ``obj``'s class is then assembled, tagged as a ``"Feature"``, and
     passed to the parent initializer.

     Args:
         obj: ORM item; besides its columns it must expose ``base_url``.
     """
     props = obj.properties.copy()
     for name in Settings.get().indexed_fields:
         # The column name is the field name without its extension namespace.
         column_value = getattr(obj, name.split(":")[-1])
         props[name] = (
             column_value.strftime(DATETIME_RFC339)
             if name == "datetime"
             else column_value
         )

     # Links derived from the item's ids, followed by any stored links.
     links = ItemLinks(
         collection_id=obj.collection_id,
         base_url=obj.base_url,
         item_id=obj.id,
     ).create_links()
     if obj.links:
         links += resolve_links(obj.links, obj.base_url)

     rebuilt = obj.__class__(
         id=obj.id,
         stac_version=obj.stac_version,
         geometry=self.decode_geom(obj.geometry),
         bbox=obj.bbox,
         properties=props,
         assets=obj.assets,
         collection_id=obj.collection_id,
         datetime=obj.datetime,
         links=links,
         stac_extensions=obj.stac_extensions,
     )
     rebuilt.type = "Feature"
     rebuilt.collection = rebuilt.collection_id
     super().__init__(rebuilt)
Example #4
0
    def stac_to_db(cls,
                   stac_data: TypedDict,
                   exclude_geometry: bool = False) -> database.Item:
        """Transform stac item to database model.

        Every name in ``Settings.get().indexed_fields`` is read from
        ``stac_data["properties"]`` and stored under the part of the name
        after the last ``:``; ``datetime`` is parsed with ``DATETIME_RFC339``.
        ``created`` is set if absent and ``updated`` is always refreshed.

        Note: ``stac_data["properties"]`` is modified in place with the
        ``created`` / ``updated`` timestamps.

        Args:
            stac_data: STAC item as a plain dict.
            exclude_geometry: Accepted for interface compatibility; not used
                by this method.

        Returns:
            The ``database.Item`` built from ``stac_data``.
        """
        properties = stac_data["properties"]

        indexed_fields = {}
        for field in Settings.get().indexed_fields:
            # Plain dict lookup; the column name drops any extension namespace.
            field_value = properties[field]
            if field == "datetime":
                field_value = datetime.strptime(field_value, DATETIME_RFC339)
            indexed_fields[field.split(":")[-1]] = field_value

        # TODO: Exclude indexed fields from the properties jsonb field to prevent duplication

        # Timestamps are stamped once per item, after (not inside) the loop
        # above, so they are set even when no fields are indexed.
        now = datetime.utcnow().strftime(DATETIME_RFC339)
        if "created" not in properties:
            properties["created"] = now
        properties["updated"] = now

        return database.Item(
            id=stac_data["id"],
            collection_id=stac_data["collection"],
            stac_version=stac_data["stac_version"],
            stac_extensions=stac_data.get("stac_extensions"),
            geometry=json.dumps(stac_data["geometry"]),
            bbox=stac_data["bbox"],
            properties=properties,
            assets=stac_data["assets"],
            **indexed_fields,
        )
Example #5
0
    def __attrs_post_init__(self):
        """Finish wiring the application once attrs has built the instance.

        Copies API metadata (extensions, STAC version, title, description)
        onto the client, applies the fields extension's ``default_includes``
        when that extension is present, publishes the settings both via
        ``Settings.set`` and on ``app.state``, registers the core endpoints,
        the router and every extension, then adds the health check, exception
        handlers, the OpenAPI override and the configured middlewares.

        Returns:
            None
        """
        client = self.client
        client.extensions = self.extensions
        client.stac_version = self.stac_version
        client.title = self.title
        client.description = self.description

        fields_extension = self.get_extension(FieldsExtension)
        if fields_extension:
            # The fields extension dictates which fields are always returned.
            self.settings.default_includes = fields_extension.default_includes

        # Expose the settings through the global accessor and on app.state.
        Settings.set(self.settings)
        self.app.state.settings = self.settings

        # Core STAC endpoints first, then the router, then each extension.
        self.register_core()
        self.app.include_router(self.router)
        for extension in self.extensions:
            extension.register(self.app)

        self.add_health_check()
        add_exception_handlers(self.app, status_codes=self.exceptions)
        self.app.openapi = self.customize_openapi

        for middleware in self.middlewares:
            self.app.add_middleware(middleware)
Example #6
0
    def filter_fields(self) -> Dict:
        """Create pydantic include/exclude expression.

        Create dictionary of fields to include/exclude on model export based on
        the included and excluded fields passed to the API.
        Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude

        ``Settings.get().default_includes`` are always included: they are
        added to the include set AND removed from the exclude set, because
        pydantic gives ``exclude`` precedence over ``include`` when a field is
        listed in both.

        Returns:
            ``{"include": ..., "exclude": ...}``, each built with
            ``self._get_field_dict``.
        """
        default_includes = set(Settings.get().default_includes or ())

        # Always include default_includes, even if they
        # exist in the exclude list.
        include = (self.include or set()) - (self.exclude or set())
        include |= default_includes

        exclude = self.exclude
        if exclude:
            # NOTE(review): only exact dotted names are removed here; excluding
            # a whole parent (e.g. "properties") may still hide nested defaults.
            exclude = set(exclude) - default_includes

        return {
            "include": self._get_field_dict(include),
            "exclude": self._get_field_dict(exclude),
        }
Example #7
0
    def db_to_stac(cls, db_model: database.Item,
                   base_url: str) -> stac_types.Item:
        """Transform database model to stac item.

        Indexed columns are written back into a copy of ``db_model.properties``
        (``datetime`` formatted with ``DATETIME_RFC339``), item links are
        inferred from the ids and extended with any stored links, and the
        geometry is normalised to a GeoJSON-like mapping whether it arrives as
        a ``WKBElement`` or a JSON string.

        Args:
            db_model: Database row to convert.
            base_url: Base URL used when building links.

        Returns:
            The STAC item.
        """
        properties = db_model.properties.copy()
        for key in Settings.get().indexed_fields:
            # The column name is the field name without its extension namespace.
            raw = getattr(db_model, key.split(":")[-1])
            properties[key] = (
                raw.strftime(DATETIME_RFC339) if key == "datetime" else raw
            )

        links = ItemLinks(
            collection_id=db_model.collection_id,
            item_id=db_model.id,
            base_url=base_url,
        ).create_links()
        stored_links = db_model.links
        if stored_links:
            links += resolve_links(stored_links, base_url)

        extensions = db_model.stac_extensions or []

        # The custom geometry type yields geojson when bound to the database,
        # otherwise a geoalchemy2 WKBElement; a JSON string is also accepted.
        # TODO: It's probably best to just remove the custom geometry type
        geom = db_model.geometry
        if isinstance(geom, ga.elements.WKBElement):
            geom = ga.shape.to_shape(geom).__geo_interface__
        if isinstance(geom, str):
            geom = json.loads(geom)

        return stac_types.Item(
            type="Feature",
            stac_version=db_model.stac_version,
            stac_extensions=extensions,
            id=db_model.id,
            collection=db_model.collection_id,
            geometry=geom,
            bbox=list(map(float, db_model.bbox)),
            properties=properties,
            links=links,
            assets=db_model.assets,
        )
Example #8
0
    def post_search(self, search_request: STACSearch,
                    **kwargs) -> Dict[str, Any]:
        """POST search catalog.

        Builds an item query from ``search_request`` (collections, sort, ids,
        spatial/temporal filters and property queries), pages it with the
        request token and returns a FeatureCollection-shaped dict with
        pagination links, the bounding box of the returned items and, when
        the context extension is enabled, a ``context`` object.

        Args:
            search_request: Parsed search body.
            **kwargs: Must contain ``request``; its ``base_url`` is used for
                links and item serialization.

        Returns:
            Dict with ``type``, ``context``, ``features``, ``links`` and
            ``bbox`` keys.
        """
        with self.session.reader.context_session() as session:
            token = (self.get_token(search_request.token)
                     if search_request.token else False)
            query = session.query(self.item_table)

            # Filter by collection
            count = None
            if search_request.collections:
                query = query.join(self.collection_table).filter(
                    sa.or_(*[
                        self.collection_table.id == col_id
                        for col_id in search_request.collections
                    ]))

            # Sort
            if search_request.sortby:
                sort_fields = [
                    getattr(self.item_table.get_field(sort.field),
                            sort.direction.value)()
                    for sort in search_request.sortby
                ]
                # Item id as a final tie-breaker
                sort_fields.append(self.item_table.id)
                query = query.order_by(*sort_fields)
            else:
                # Default sort is date
                query = query.order_by(self.item_table.datetime.desc(),
                                       self.item_table.id)

            # Ignore other parameters if ID is present
            if search_request.ids:
                id_filter = sa.or_(
                    *[self.item_table.id == i for i in search_request.ids])
                items = query.filter(id_filter).order_by(self.item_table.id)
                page = get_page(items,
                                per_page=search_request.limit,
                                page=token)
                if self.extension_is_enabled(ContextExtension):
                    count = len(search_request.ids)
                page.next = (self.insert_token(
                    keyset=page.paging.bookmark_next)
                             if page.paging.has_next else None)
                page.previous = (self.insert_token(
                    keyset=page.paging.bookmark_previous)
                                 if page.paging.has_previous else None)

            else:
                # Spatial query
                poly = None
                if search_request.intersects is not None:
                    poly = shape(search_request.intersects)
                elif search_request.bbox:
                    poly = ShapelyPolygon.from_bounds(*search_request.bbox)

                if poly:
                    filter_geom = ga.shape.from_shape(poly, srid=4326)
                    query = query.filter(
                        ga.func.ST_Intersects(self.item_table.geometry,
                                              filter_geom))

                # Temporal query
                if search_request.datetime:
                    # Two tailed query (between)
                    if ".." not in search_request.datetime:
                        query = query.filter(
                            self.item_table.datetime.between(
                                *search_request.datetime))
                    # All items after the start date
                    if search_request.datetime[0] != "..":
                        query = query.filter(self.item_table.datetime >=
                                             search_request.datetime[0])
                    # All items before the end date
                    if search_request.datetime[1] != "..":
                        query = query.filter(self.item_table.datetime <=
                                             search_request.datetime[1])

                # Query fields
                if search_request.query:
                    for (field_name, expr) in search_request.query.items():
                        field = self.item_table.get_field(field_name)
                        for (op, value) in expr.items():
                            query = query.filter(op.operator(field, value))

                if self.extension_is_enabled(ContextExtension):
                    # Count matches without the ORDER BY, which COUNT ignores.
                    count_query = query.statement.with_only_columns(
                        [func.count()]).order_by(None)
                    count = query.session.execute(count_query).scalar()
                page = get_page(query,
                                per_page=search_request.limit,
                                page=token)
                # Create dynamic attributes for each page
                page.next = (self.insert_token(
                    keyset=page.paging.bookmark_next)
                             if page.paging.has_next else None)
                page.previous = (self.insert_token(
                    keyset=page.paging.bookmark_previous)
                                 if page.paging.has_previous else None)

            links = []
            if page.next:
                links.append(
                    PaginationLink(
                        rel=Relations.next,
                        type="application/geo+json",
                        href=f"{kwargs['request'].base_url}search",
                        method="POST",
                        body={"token": page.next},
                        merge=True,
                    ))
            if page.previous:
                links.append(
                    PaginationLink(
                        rel=Relations.previous,
                        type="application/geo+json",
                        href=f"{kwargs['request'].base_url}search",
                        method="POST",
                        body={"token": page.previous},
                        merge=True,
                    ))

            response_features = []
            filter_kwargs = {}
            if self.extension_is_enabled(FieldsExtension):
                if search_request.query is not None:
                    indexed_fields = Settings.get().indexed_fields
                    # Fields used in the query are always returned; non-indexed
                    # ones are addressed under ``properties``.
                    query_include: Set[str] = {
                        k if k in indexed_fields else f"properties.{k}"
                        for k in search_request.query.keys()
                    }
                    if not search_request.field.include:
                        search_request.field.include = query_include
                    else:
                        # ``union`` returns a new set rather than updating in
                        # place, so the result must be assigned back.
                        search_request.field.include = (
                            search_request.field.include.union(query_include))

                filter_kwargs = search_request.field.filter_fields

            xvals = []
            yvals = []
            for item in page:
                item.base_url = str(kwargs["request"].base_url)
                item_model = schemas.Item.from_orm(item)
                xvals += [item_model.bbox[0], item_model.bbox[2]]
                yvals += [item_model.bbox[1], item_model.bbox[3]]
                response_features.append(item_model.to_dict(**filter_kwargs))

        # Overall bbox of the returned page; None when the page is empty.
        try:
            bbox = (min(xvals), min(yvals), max(xvals), max(yvals))
        except ValueError:
            bbox = None

        context_obj = None
        if self.extension_is_enabled(ContextExtension):
            context_obj = {
                "returned": len(page),
                "limit": search_request.limit,
                "matched": count,
            }

        return {
            "type": "FeatureCollection",
            "context": context_obj,
            "features": response_features,
            "links": links,
            "bbox": bbox,
        }
Example #9
0
    def post_search(self, search_request: SQLAlchemySTACSearch,
                    **kwargs) -> ItemCollection:
        """POST search catalog.

        Builds an item query from ``search_request`` (collections, sort, ids,
        spatial/temporal filters and property queries), pages it with the
        request token, serializes each item with ``self.item_serializer`` and
        applies the fields extension (when enabled) via pydantic
        include/exclude.

        The ``datetime`` filter accepts a single instant (exact match) or a
        ``start/end`` interval where ``..`` or an empty string marks an open
        end.

        Args:
            search_request: Parsed search body.
            **kwargs: Must contain ``request``; its ``base_url`` is used for
                links and item serialization.

        Returns:
            The ``ItemCollection`` for the requested page.
        """
        base_url = str(kwargs["request"].base_url)
        with self.session.reader.context_session() as session:
            token = (self.get_token(search_request.token)
                     if search_request.token else False)
            query = session.query(self.item_table)

            # Filter by collection
            count = None
            if search_request.collections:
                query = query.join(self.collection_table).filter(
                    sa.or_(*[
                        self.collection_table.id == col_id
                        for col_id in search_request.collections
                    ]))

            # Sort
            if search_request.sortby:
                sort_fields = [
                    getattr(
                        self.item_table.get_field(sort.field),
                        sort.direction.value,
                    )() for sort in search_request.sortby
                ]
                # Item id as a final tie-breaker
                sort_fields.append(self.item_table.id)
                query = query.order_by(*sort_fields)
            else:
                # Default sort is date
                query = query.order_by(self.item_table.datetime.desc(),
                                       self.item_table.id)

            # Ignore other parameters if ID is present
            if search_request.ids:
                id_filter = sa.or_(
                    *[self.item_table.id == i for i in search_request.ids])
                items = query.filter(id_filter).order_by(self.item_table.id)
                page = get_page(items,
                                per_page=search_request.limit,
                                page=token)
                if self.extension_is_enabled(ContextExtension):
                    count = len(search_request.ids)
                page.next = (self.insert_token(
                    keyset=page.paging.bookmark_next)
                             if page.paging.has_next else None)
                page.previous = (self.insert_token(
                    keyset=page.paging.bookmark_previous)
                                 if page.paging.has_previous else None)

            else:
                # Spatial query
                poly = None
                if search_request.intersects is not None:
                    poly = shape(search_request.intersects)
                elif search_request.bbox:
                    poly = ShapelyPolygon.from_bounds(*search_request.bbox)

                if poly:
                    filter_geom = ga.shape.from_shape(poly, srid=4326)
                    query = query.filter(
                        ga.func.ST_Intersects(self.item_table.geometry,
                                              filter_geom))

                # Temporal query
                if search_request.datetime:
                    dts = search_request.datetime.split("/")
                    if len(dts) == 1:
                        # A single instant rather than an interval
                        query = query.filter(
                            self.item_table.datetime == dts[0])
                    else:
                        open_bounds = ("..", "")
                        start, end = dts[0], dts[1]
                        # All items at or after the start date
                        if start not in open_bounds:
                            query = query.filter(
                                self.item_table.datetime >= start)
                        # All items at or before the end date
                        if end not in open_bounds:
                            query = query.filter(
                                self.item_table.datetime <= end)

                # Query fields
                if search_request.query:
                    for (field_name, expr) in search_request.query.items():
                        field = self.item_table.get_field(field_name)
                        for (op, value) in expr.items():
                            query = query.filter(op.operator(field, value))

                if self.extension_is_enabled(ContextExtension):
                    # Count matches without the ORDER BY, which COUNT ignores.
                    count_query = query.statement.with_only_columns(
                        [func.count()]).order_by(None)
                    count = query.session.execute(count_query).scalar()
                page = get_page(query,
                                per_page=search_request.limit,
                                page=token)
                # Create dynamic attributes for each page
                page.next = (self.insert_token(
                    keyset=page.paging.bookmark_next)
                             if page.paging.has_next else None)
                page.previous = (self.insert_token(
                    keyset=page.paging.bookmark_previous)
                                 if page.paging.has_previous else None)

            links = []
            if page.next:
                links.append({
                    "rel": Relations.next.value,
                    "type": "application/geo+json",
                    "href": f"{base_url}search",
                    "method": "POST",
                    "body": {
                        "token": page.next
                    },
                    "merge": True,
                })
            if page.previous:
                links.append({
                    "rel": Relations.previous.value,
                    "type": "application/geo+json",
                    "href": f"{base_url}search",
                    "method": "POST",
                    "body": {
                        "token": page.previous
                    },
                    "merge": True,
                })

            response_features = [
                self.item_serializer.db_to_stac(item, base_url=base_url)
                for item in page
            ]

            # Use pydantic includes/excludes syntax to implement fields extension
            if self.extension_is_enabled(FieldsExtension):
                if search_request.query is not None:
                    indexed_fields = Settings.get().indexed_fields
                    # Fields used in the query are always returned; non-indexed
                    # ones are addressed under ``properties``.
                    query_include: Set[str] = {
                        k if k in indexed_fields else f"properties.{k}"
                        for k in search_request.query.keys()
                    }
                    if not search_request.field.include:
                        search_request.field.include = query_include
                    else:
                        # ``union`` returns a new set rather than updating in
                        # place, so the result must be assigned back.
                        search_request.field.include = (
                            search_request.field.include.union(query_include))

                filter_kwargs = search_request.field.filter_fields
                # Need to pass through `.json()` for proper serialization
                # of datetime
                response_features = [
                    json.loads(
                        stac_pydantic.Item(**feat).json(**filter_kwargs))
                    for feat in response_features
                ]

        context_obj = None
        if self.extension_is_enabled(ContextExtension):
            context_obj = {
                "returned": len(page),
                "limit": search_request.limit,
                "matched": count,
            }

        # TODO: return stac_extensions
        return ItemCollection(
            type="FeatureCollection",
            stac_version=STAC_VERSION,
            features=response_features,
            links=links,
            context=context_obj,
        )
Example #10
0
    BulkTransactionsClient,
    TransactionsClient,
)
from stac_fastapi.sqlalchemy.types.search import SQLAlchemySTACSearch
from stac_fastapi.types.config import Settings

# "data" directory sitting next to this module (presumably test fixture files).
DATA_DIR = os.path.join(os.path.split(__file__)[0], "data")


class TestSettings(SqlalchemySettings):
    """Test-suite settings; ``Config.env_file`` points at ``.env.test``."""

    class Config:
        # Env file the settings are read from (pydantic ``BaseSettings``
        # convention, presumably inherited via SqlalchemySettings — confirm).
        env_file = ".env.test"


# Build the test settings at import time and register them as the shared
# settings, so code calling ``Settings.get()`` sees the test configuration.
settings = TestSettings()
Settings.set(settings)


@pytest.fixture(autouse=True)
def cleanup(postgres_core: CoreCrudClient,
            postgres_transactions: TransactionsClient):
    yield
    collections = postgres_core.all_collections(request=MockStarletteRequest)
    for coll in collections:
        if coll.id.split("-")[0] == "test":
            # Delete the items
            items = postgres_core.item_collection(coll.id,
                                                  limit=100,
                                                  request=MockStarletteRequest)
            for feat in items.features:
                postgres_transactions.delete_item(feat.id,