Beispiel #1
0
    def check_url(url_query: str) -> set:
        """Check a parsed URL query part for parameters not followed by `=`.

        URL query parameters are considered to be separated by either an
        ampersand (`&`) or a semicolon (`;`).

        Parameters:
            url_query: The raw urllib-parsed query part.

        Raises:
            BadRequest: If a non-empty query parameter does not contain `=`.

        Returns:
            The set of individual query parameters (as `key=value` strings).
            Empty segments, e.g. from an empty query or from `a=1&&b=2`, are
            kept as the empty string.

            This is mainly for testing and not actually needed by the middleware,
            since if the URL exhibits an invalid query part a `400 Bad Request`
            response will be returned.

        """
        # Split on `&` first and then each chunk on `;`, so both act as separators.
        queries = {
            parameter
            for chunk in url_query.split("&")
            for parameter in chunk.split(";")
        }
        for query in queries:
            # Empty segments are tolerated; anything else must carry a value.
            if query and "=" not in query:
                raise BadRequest(
                    detail=
                    "A query parameter without an equal sign (=) is not supported by this server"
                )
        return queries  # Useful for testing
Beispiel #2
0
    def handle_query_params(
        self, params: Union[EntryListingQueryParams,
                            SingleEntryQueryParams]) -> dict:
        """Parse and interpret the backend-agnostic query parameter models into a dictionary
        that can be used by the specific backend.

        Note:
            Currently this method returns the pymongo interpretation of the parameters,
            which will need modification for other backends.

        Parameters:
            params: The initialized query parameter model from the server.

        Raises:
            Forbidden: If too large of a page limit is provided.
            BadRequest: If an invalid request is made, e.g., with incorrect fields
                or response format.

        Returns:
            A dictionary representation of the query parameters, ready to be used by pymongo.
            It always contains the keys `filter`, `limit`, `fields` and `projection`;
            `sort` and `skip` are only present when the corresponding parameters are set.

        """
        cursor_kwargs = {}

        # An absent or empty filter matches everything.
        if getattr(params, "filter", False):
            tree = self.parser.parse(params.filter)
            cursor_kwargs["filter"] = self.transformer.transform(tree)
        else:
            cursor_kwargs["filter"] = {}

        if (getattr(params, "response_format", False)
                and params.response_format != "json"):
            raise BadRequest(
                detail=
                f"Response format {params.response_format} is not supported, please use response_format='json'"
            )

        # A falsy page_limit (unset or 0) falls back to the configured default.
        if getattr(params, "page_limit", False):
            limit = params.page_limit
            if limit > CONFIG.page_limit_max:
                raise Forbidden(
                    detail=
                    f"Max allowed page_limit is {CONFIG.page_limit_max}, you requested {limit}",
                )
            cursor_kwargs["limit"] = limit
        else:
            cursor_kwargs["limit"] = CONFIG.page_limit

        # `fields` holds the OPTIMADE names, `projection` their backend aliases.
        cursor_kwargs["fields"] = self.all_fields
        cursor_kwargs["projection"] = [
            self.resource_mapper.alias_for(f) for f in self.all_fields
        ]

        if getattr(params, "sort", False):
            cursor_kwargs["sort"] = self.parse_sort_params(params.sort)

        if getattr(params, "page_offset", False):
            cursor_kwargs["skip"] = params.page_offset

        return cursor_kwargs
Beispiel #3
0
        def replace_only_filter(subdict: dict, prop: str, expr: dict):
            """Replace the magic key `"#only"` (added by this transformer) with an `$elemMatch`-based query.

            The first part of the query selects all the documents that contain any value that does not
            match any target values for the property `prop`.
            Subsequently, this selection is inverted, to get the documents that only have
            the allowed values.
            This inversion also selects documents with edge-case values such as null or empty lists;
            these are removed in the second part of the query that makes sure that only documents
            with lists that have at least one value are selected.

            Parameters:
                subdict: The (sub-)query dictionary containing `prop` as a key; it is
                    modified in place (`prop` is removed, conditions are appended to `$and`).
                prop: The queried property, e.g. `elements` or
                    `relationships.references.data.id`.
                expr: The expression holding the allowed values under the `"#only"` key.

            Raises:
                BadRequest: If `prop` is a relationship field other than
                    `relationships.references.data.id` or `relationships.structures.data.id`.

            Returns:
                The modified `subdict`.

            """
            subdict.setdefault("$and", [])

            if prop.startswith("relationships."):
                if prop not in (
                        "relationships.references.data.id",
                        "relationships.structures.data.id",
                ):
                    raise BadRequest(
                        detail=f"Unable to query on unrecognised field {prop}.")
                # Relationship data is a list of objects: query that list (drop the
                # trailing `.id`) and compare the `id` key of each element.
                target = prop.rsplit(".", 1)[0]
                excluded = {"id": {"$nin": expr["#only"]}}
            else:
                target = prop
                excluded = {"$nin": expr["#only"]}

            # No element may lie outside the allowed values ...
            subdict["$and"].append({target: {"$not": {"$elemMatch": excluded}}})
            # ... and the list must be non-empty (excludes null/empty edge cases).
            subdict["$and"].append({target + ".0": {"$exists": True}})

            subdict.pop(prop)
            return subdict
Beispiel #4
0
 def check_url(url_query: str) -> set:
     """Check a parsed URL query part for parameters not followed by `=`.

     URL query parameters are considered to be separated by either an
     ampersand (`&`) or a semicolon (`;`).

     Parameters:
         url_query: The raw urllib-parsed query part.

     Raises:
         BadRequest: If a non-empty query parameter does not contain `=`.

     Returns:
         The set of individual query parameters (as `key=value` strings).
         Empty segments, e.g. from an empty query or from `a=1&&b=2`, are
         kept as the empty string. This is mainly useful for testing.

     """
     # Split on `&` first and then each chunk on `;`, so both act as separators.
     queries = {
         parameter
         for chunk in url_query.split("&")
         for parameter in chunk.split(";")
     }
     for query in queries:
         # Empty segments are tolerated; anything else must carry a value.
         if query and "=" not in query:
             raise BadRequest(
                 detail=
                 "A query parameter without an equal sign (=) is not supported by this server"
             )
     return queries  # Useful for testing
Beispiel #5
0
    def parse_sort_params(self, sort_params: str) -> Tuple[Tuple[str, int], ...]:
        """Handles any sort parameters passed to the collection,
        resolving aliases and dealing with any invalid fields.

        Parameters:
            sort_params: Comma-separated field names to sort on; a leading `-`
                requests descending order for that field.

        Raises:
            BadRequest: if an invalid sort is requested, i.e. some unknown field
                is not prefixed with another provider's prefix.

        Warns:
            FieldValueNotRecognized: if all unknown fields look like other
                providers' fields; those fields are then dropped from the result.

        Returns:
            A (possibly empty) tuple of tuples containing the aliased field name and
            sort direction encoded as 1 (ascending) or -1 (descending).

        """
        sort_spec = []
        for field in sort_params.split(","):
            sort_dir = 1
            if field.startswith("-"):
                field = field[1:]
                sort_dir = -1
            aliased_field = self.resource_mapper.get_backend_field(field)
            sort_spec.append((aliased_field, sort_dir))

        # NOTE(review): `unknown_fields` holds backend (aliased) names, so the error
        # message and the prefix checks below see backend names - confirm intended.
        unknown_fields = [
            field for field, _ in sort_spec
            if self.resource_mapper.get_optimade_field(field) not in
            self.all_fields
        ]

        if unknown_fields:
            error_detail = "Unable to sort on unknown field{} '{}'".format(
                "s" if len(unknown_fields) > 1 else "",
                "', '".join(unknown_fields),
            )

            # If all unknown fields are "other" provider-specific, then only provide a warning
            if all((re.match(r"_[a-z_0-9]+_[a-z_0-9]*", field)
                    and not field.startswith(f"_{self.provider_prefix}_"))
                   for field in unknown_fields):
                warnings.warn(error_detail, FieldValueNotRecognized)

            # Otherwise, if all fields are unknown, or some fields are unknown and do not
            # have other provider prefixes, then return 400: Bad Request
            else:
                raise BadRequest(detail=error_detail)

        # If at least one valid field has been provided for sorting, then use that
        unknown = set(unknown_fields)
        return tuple((field, sort_dir) for field, sort_dir in sort_spec
                     if field not in unknown)
    def property(self, args: list) -> Any:
        """property: IDENTIFIER ( "." IDENTIFIER )*

        Resolve the first identifier of a property rule.

        When a mapper is associated with this transformer, two kinds of names
        are handed back as plain strings rather than dereferenced
        [`Quantity`][optimade.filtertransformers.base_transformer.Quantity]
        objects: names of relationship entry types, and provider-specific names
        whose prefix is not one of this provider's supported prefixes (these are
        treated as UNKNOWN). Any other name is looked up in `self.quantities`; a
        new quantity is created if no entry exists.

        Raises:
            BadRequest: If `self.quantities` is non-empty, the name is not in it,
                and it is not a provider-specific name as described above.

        """
        name = str(args[0])
        mapper = self.mapper

        # Relationship filter: hand back the bare entry type name; the inherited
        # property must then deal with any further nested identifiers.
        if mapper and name in mapper.RELATIONSHIP_ENTRY_TYPES:
            return name

        known_quantities = self.quantities
        if known_quantities and name not in known_quantities:
            # Other providers' fields are treated as unknown, following
            # [Handling unknown property names](https://github.com/Materials-Consortia/OPTIMADE/blob/master/optimade.rst#handling-unknown-property-names).
            # A warning is only emitted when the prefix belongs to no known provider.
            if mapper and name.startswith("_"):
                prefix = name.split("_")[1]
                if prefix not in mapper.SUPPORTED_PREFIXES:
                    if prefix not in mapper.KNOWN_PROVIDER_PREFIXES:
                        warnings.warn(
                            UnknownProviderProperty(
                                f"Field {name!r} has an unrecognised prefix: this property has been treated as UNKNOWN."
                            ))
                    return name

            raise BadRequest(
                detail=f"'{name}' is not a known or searchable quantity")

        found = known_quantities.get(name, None)
        if found is not None:
            return found
        return self._quantity_type(name=str(name))
Beispiel #7
0
        def replace_str_id_with_objectid(subdict, prop, expr):
            """Replace string values compared against `prop` with `bson.ObjectId`s.

            Only equality (`$eq`) and inequality (`$ne`) comparisons are supported.

            Parameters:
                subdict: The (sub-)query dictionary; `subdict[prop]` maps MongoDB
                    operators to the compared values and is modified in place.
                prop: The key in `subdict` being filtered on.
                expr: Not used by this function.

            Raises:
                BadRequest: If an operator other than `$eq`/`$ne` is used, or if a
                    string value cannot be converted into an ObjectId.

            Returns:
                The updated `subdict`.

            """
            from bson import ObjectId
            from bson.errors import InvalidId

            for operator in subdict[prop]:
                val = subdict[prop][operator]
                if operator not in ("$eq", "$ne"):
                    # Report the mapper's alias of the field when a mapper is set
                    # (presumably the user-facing name - confirm against the mapper).
                    field = prop if self.mapper is None else self.mapper.alias_of(prop)
                    raise BadRequest(
                        detail=
                        f"Operator not supported for query on field {field!r}, can only test for equality"
                    )
                if isinstance(val, str):
                    try:
                        subdict[prop][operator] = ObjectId(val)
                    except InvalidId as exc:
                        # A malformed ID in the filter is a client error, not a server error.
                        raise BadRequest(
                            detail=f"Unable to interpret {val!r} as a valid ObjectId: {exc}"
                        ) from exc
            return subdict
    def parse(self, filter_: str) -> Tree:
        """Run the Lark parser over an OPTIMADE filter string.

        On success the resulting tree is stored on `self.tree` and the original
        filter string on `self.filter`.

        Parameters:
            filter_: The filter string to parse.

        Raises:
            BadRequest: If parsing fails for any reason; the underlying exception
                is chained and its message is included in the error detail.

        Returns:
            The parsed `lark.Tree`.

        """
        try:
            parsed_tree = self.lark.parse(filter_)
            self.tree = parsed_tree
            self.filter = filter_
            return self.tree
        except Exception as exc:
            detail = f"Unable to parse filter {filter_}. Lark traceback: \n{exc}"
            raise BadRequest(detail=detail) from exc
Beispiel #9
0
    def find(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> Tuple[Union[List[EntryResource], EntryResource, None], int, bool,
               Set[str], Set[str]]:
        """
        Run the query described by `params` and report whether more data is available.

        The total number of matching entries is also reported, independently of
        `page_limit`.
        See [`EntryListingQueryParams`][optimade.server.query_params.EntryListingQueryParams]
        for more information.

        Parameters:
            params: Entry listing (or single-entry) URL query params.

        Raises:
            NotFound: If `params` is a single-entry request but more than one
                entry matched.
            BadRequest: If unrecognised OPTIMADE fields are requested in
                `response_fields`.

        Returns:
            A tuple of various relevant values:
            (`results`, `data_returned`, `more_data_available`, `exclude_fields`, `include_fields`).

        """
        criteria = self.handle_query_params(params)
        is_single = isinstance(params, SingleEntryQueryParams)
        requested_fields = criteria.pop("fields")

        results, data_returned, more_data_available = self._run_db_query(
            criteria, is_single)

        if is_single:
            # Unwrap the (at most) one matching entry.
            results = results[0] if results else None
            if data_returned > 1:
                raise NotFound(
                    detail=
                    f"Instead of a single entry, {data_returned} entries were found",
                )

        exclude_fields = self.all_fields - requested_fields
        include_fields = (requested_fields -
                          self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS)

        # Requested fields that are not attributes of this resource. Unprefixed
        # ones are invalid OPTIMADE fields; prefixed ones only matter (as a
        # warning) when the prefix is one this provider supports.
        unrecognised = [
            name for name in include_fields
            if name not in self.resource_mapper.ALL_ATTRIBUTES
        ]
        bad_optimade_fields = {
            name for name in unrecognised if not name.startswith("_")
        }
        bad_provider_fields = {
            name for name in unrecognised
            if name.startswith("_") and any(
                name.startswith(f"_{prefix}_")
                for prefix in self.resource_mapper.SUPPORTED_PREFIXES)
        }

        if bad_provider_fields:
            warnings.warn(
                message=
                f"Unrecognised field(s) for this provider requested in `response_fields`: {bad_provider_fields}.",
                category=UnknownProviderProperty,
            )

        if bad_optimade_fields:
            raise BadRequest(
                detail=
                f"Unrecognised OPTIMADE field(s) in requested `response_fields`: {bad_optimade_fields}."
            )

        if results:
            results = self.resource_mapper.deserialize(results)

        return (
            results,
            data_returned,
            more_data_available,
            exclude_fields,
            include_fields,
        )
Beispiel #10
0
def get_included_relationships(
    results: Union[EntryResource, List[EntryResource]],
    ENTRY_COLLECTIONS: Dict[str, EntryCollection],
    include_param: List[str],
) -> List[EntryResource]:
    """Collect the resources referenced by the requested relationships of `results`.

    For every requested relationship type, the unique referenced IDs are gathered
    and fetched with a single compound `id="..." OR id="..."` filter from the
    corresponding collection.

    Parameters:
        results: list of returned documents (or a single document).
        ENTRY_COLLECTIONS: dictionary containing collections to query, with key
            based on endpoint type.
        include_param: list of queried related resources that should be included in
            `included`.

    Raises:
        BadRequest: If a non-empty entry of `include_param` is not a key of
            `ENTRY_COLLECTIONS`.

    Returns:
        A flat list of the resource objects of all requested relationship types.

    """
    from collections import defaultdict

    if not isinstance(results, list):
        results = [results]

    for entry_type in include_param:
        if entry_type not in ENTRY_COLLECTIONS and entry_type != "":
            raise BadRequest(
                detail=f"'{entry_type}' cannot be identified as a valid relationship type. "
                f"Known relationship types: {sorted(ENTRY_COLLECTIONS.keys())}"
            )

    # entry type -> {referenced ID: reference}; keying by ID keeps only unique IDs.
    endpoint_includes = defaultdict(dict)
    for doc in results:
        if doc is None or doc.relationships is None:
            continue

        relationships = doc.relationships.dict()
        for entry_type in ENTRY_COLLECTIONS:
            # Skip entry type if it is not in `include_param`
            if entry_type not in include_param:
                continue

            entry_relationship = relationships.get(entry_type, {})
            if entry_relationship is None:
                continue
            # `data` may be present but None, hence `or []`.
            for ref in entry_relationship.get("data") or []:
                # could check here and raise a warning if any IDs clash
                endpoint_includes[entry_type][ref["id"]] = ref

    included = {}
    for entry_type, refs_by_id in endpoint_includes.items():
        compound_filter = " OR ".join(
            'id="{}"'.format(ref_id) for ref_id in refs_by_id
        )
        params = EntryListingQueryParams(
            filter=compound_filter,
            response_format="json",
            response_fields=None,
            sort=None,
            page_limit=0,
            page_offset=0,
        )

        # still need to handle pagination
        # `find()` returns a tuple whose first item is the list of results.
        included[entry_type] = ENTRY_COLLECTIONS[entry_type].find(params)[0]

    # flatten dict by endpoint to list
    return [obj for endp in included.values() for obj in endp]
Beispiel #11
0
# The event loop only keeps weak references to tasks created with
# `asyncio.create_task()`; fire-and-forget query tasks are kept here until they
# finish so they cannot be garbage collected mid-execution.
_BACKGROUND_QUERY_TASKS = set()


def _base_url_label(url: AnyUrl) -> str:
    """Return a `[user@]host[:port][path]` label for `url`, without a trailing `/`."""
    return (f"{url.user + '@' if url.user else ''}{url.host}"
            f"{':' + url.port if url.port else ''}"
            f"{url.path.rstrip('/') if url.path else ''}")


async def post_search(request: Request,
                      search: Search) -> QueriesResponseSingle:
    """`POST /search`

    Coordinate a new OPTIMADE query in multiple databases through a gateway:

    1. Search for gateway in DB using `optimade_urls` and `database_ids`
    1. Create [`GatewayCreate`][optimade_gateway.models.gateways.GatewayCreate] model
    1. `POST` gateway resource to get ID - using functionality of `POST /gateways`
    1. Create new [Query][optimade_gateway.models.queries.QueryCreate] resource
    1. `POST` Query resource - using functionality of `POST /queries`
    1. Return `POST /queries` response -
        [`QueriesResponseSingle`][optimade_gateway.models.responses.QueriesResponseSingle]

    Raises:
        BadRequest: If no (valid) OPTIMADE base URL could be determined.
        InternalServerError: If not all base URLs are `AnyUrl`s, or if more stored
            databases than base URLs were found (ambiguous base URLs).

    """
    databases_collection = await collection_factory(CONFIG.databases_collection)
    # NOTE: It may be that the final list of base URLs (`base_urls`) contains the same
    # provider(s), but with differring base URLS, if, for example, a versioned base URL
    # is supplied.
    base_urls = set()

    if search.database_ids:
        databases = await databases_collection.get_multiple(filter={
            "id": {
                "$in": await clean_python_types(search.database_ids)
            }
        })
        base_urls |= {
            base_url
            for base_url in (
                get_resource_attribute(database, "attributes.base_url")
                for database in databases)
            if base_url is not None
        }

    if search.optimade_urls:
        base_urls |= {url for url in search.optimade_urls if url is not None}

    if not base_urls:
        msg = "No (valid) OPTIMADE URLs with:"
        if search.database_ids:
            msg += (
                f"\n  Database IDs: {search.database_ids} and corresponding found URLs: "
                f"{[get_resource_attribute(database, 'attributes.base_url') for database in databases]}"
            )
        if search.optimade_urls:
            msg += f"\n  Passed OPTIMADE URLs: {search.optimade_urls}"
        raise BadRequest(detail=msg)

    # Ensure all URLs are `pydantic.AnyUrl`s
    if not all(isinstance(url, AnyUrl) for url in base_urls):
        raise InternalServerError(
            "Could unexpectedly not validate all base URLs as proper URLs.")

    databases = await databases_collection.get_multiple(
        filter={"base_url": {
            "$in": await clean_python_types(base_urls)
        }})
    if len(databases) < len(base_urls):
        # There are unregistered databases: create ad-hoc links resources for them.
        current_base_urls = {
            get_resource_attribute(database, "attributes.base_url")
            for database in databases
        }
        databases.extend([
            LinksResource(
                id=_base_url_label(url).replace(".", "__"),
                type="links",
                attributes=LinksResourceAttributes(
                    name=_base_url_label(url),
                    description="",
                    base_url=url,
                    link_type=LinkType.CHILD,
                    homepage=None,
                ),
            ) for url in base_urls - current_base_urls
        ])
    elif len(databases) > len(base_urls):
        LOGGER.error(
            "Found more database entries in MongoDB than the number of passed base URLs."
            " This suggests ambiguity in the base URLs of databases stored in MongoDB.\n"
            "  base_urls: %s\n  databases %s",
            base_urls,
            databases,
        )
        raise InternalServerError(
            "Ambiguous base URLs. See logs for more details.")
    # Otherwise `databases` is expected to be the complete set of requested databases.

    gateway = GatewayCreate(databases=databases)
    gateway, created = await resource_factory(gateway)

    if created:
        LOGGER.debug("A new gateway was created for a query (id=%r)",
                     gateway.id)
    else:
        LOGGER.debug("A gateway was found and reused for a query (id=%r)",
                     gateway.id)

    query = QueryCreate(
        endpoint=search.endpoint,
        gateway_id=gateway.id,
        query_parameters=search.query_parameters,
    )
    query, created = await resource_factory(query)

    if created:
        task = asyncio.create_task(perform_query(url=request.url, query=query))
        _BACKGROUND_QUERY_TASKS.add(task)
        task.add_done_callback(_BACKGROUND_QUERY_TASKS.discard)

    collection = await collection_factory(CONFIG.queries_collection)

    return QueriesResponseSingle(
        links=ToplevelLinks(next=None),
        data=query,
        meta=meta_values(
            url=request.url,
            data_returned=1,
            data_available=await collection.acount(),
            more_data_available=False,
            **{f"_{CONFIG.provider.prefix}_created": created},
        ),
    )
Beispiel #12
0
async def get_search(
    request: Request,
    response: Response,
    search_params: SearchQueryParams = Depends(),
    entry_params: EntryListingQueryParams = Depends(),
) -> Union[QueriesResponseSingle, EntryResponseMany, ErrorResponse,
           RedirectResponse]:
    """`GET /search`

    Coordinate a new OPTIMADE query in multiple databases through a gateway:

    1. Create a [`Search`][optimade_gateway.models.search.Search] `POST` data - calling
        `POST /search`.
    1. Wait [`search_params.timeout`][optimade_gateway.queries.params.SearchQueryParams]
        seconds before returning the query, if it has not finished before.
    1. Return query - similar to `GET /queries/{query_id}`.

    This endpoint works similarly to `GET /queries/{query_id}`, where one passes the query
    parameters directly in the URL, instead of first POSTing a query and then going to its
    URL. Hence, a
    [`QueryResponseSingle`][optimade_gateway.models.responses.QueriesResponseSingle] is
    the standard response model for this endpoint.

    If the timeout time is reached and the query has not yet finished, the user is
    redirected to the specific URL for the query.

    If the `as_optimade` query parameter is `True`, the response will be parseable as a
    standard OPTIMADE entry listing endpoint like, e.g., `/structures`.
    For more information see the
    [OPTIMADE specification](https://github.com/Materials-Consortia/OPTIMADE/blob/master/optimade.rst#entry-listing-endpoints).

    Raises:
        BadRequest: If no `Search` object can be built from the URL query parameters.
        RuntimeError: If `POST /search` does not return a query resource.

    """
    try:
        search = Search(
            query_parameters=OptimadeQueryParameters(
                **{
                    field: getattr(entry_params, field)
                    for field in OptimadeQueryParameters.__fields__
                    if getattr(entry_params, field)
                }),
            optimade_urls=search_params.optimade_urls,
            endpoint=search_params.endpoint,
            database_ids=search_params.database_ids,
        )
    except ValidationError as exc:
        # `errors()` is a method; formatting `exc.errors` would print a bound method.
        raise BadRequest(detail=(
            "A Search object could not be created from the given URL query "
            f"parameters. Error(s): {exc.errors()}")) from exc

    queries_response = await post_search(request, search=search)

    if not queries_response.data:
        LOGGER.error("QueryResource not found in POST /search response:\n%s",
                     queries_response)
        raise RuntimeError(
            "Expected the response from POST /search to return a QueryResource, it did "
            "not")

    # The collection does not change between polls, so only look it up once.
    collection = await collection_factory(CONFIG.queries_collection)

    once = True
    start_time = time()
    while (  # pylint: disable=too-many-nested-blocks
            time() < (start_time + search_params.timeout) or once):
        # Make sure to run this at least once (e.g., if timeout=0)
        once = False

        query: QueryResource = await collection.get_one(
            filter={"id": queries_response.data.id})

        if query.attributes.state == QueryState.FINISHED:
            if query.attributes.response and query.attributes.response.errors:
                # Use the first error status code >= 300 (taken from the first
                # integer token in its `status`); fall back to 500 otherwise.
                for error in query.attributes.response.errors:
                    if error.status:
                        for part in error.status.split(" "):
                            try:
                                response.status_code = int(part)
                                break
                            except ValueError:
                                pass
                        if response.status_code and response.status_code >= 300:
                            break
                else:
                    response.status_code = 500

            if search_params.as_optimade:
                return await query.response_as_optimade(url=request.url)

            return QueriesResponseSingle(
                links=ToplevelLinks(next=None),
                data=query,
                meta=meta_values(
                    url=request.url,
                    data_returned=1,
                    data_available=await collection.acount(),
                    more_data_available=False,
                ),
            )

        await asyncio.sleep(0.1)

    # The query has not yet succeeded and we're past the timeout time -> Redirect to
    # /queries/<id>
    return RedirectResponse(query.links.self)
    async def afind(
        self,
        params:
        "Optional[Union[EntryListingQueryParams, SingleEntryQueryParams]]" = None,
        criteria: "Optional[Dict[str, Any]]" = None,
    ) -> "Tuple[Union[List[EntryResource], EntryResource, None], int, bool, Set[str], Set[str]]":
        """Perform the query on the underlying MongoDB Collection, handling projection
        and pagination of the output.

        This is the asynchronous version of the parent class method named `find()`.

        Either provide `params` or `criteria`. Not both, but at least one.
        With `params`, a single result (or `None`) is returned when `params` is a
        `SingleEntryQueryParams`. With `criteria`, the kind of request is unknown,
        so a list of results is always returned.

        Parameters:
            params: URL query parameters, either from a general entry endpoint or a
                single-entry endpoint.
            criteria: Already handled/parsed URL query parameters. The passed
                dictionary is not modified.

        Raises:
            ValueError: If neither or both of `params` and `criteria` are given.
            NotFound: If a single entry was requested but several entries matched.
            BadRequest: If unrecognised OPTIMADE fields are requested in
                `response_fields`.

        Returns:
            A list of entry resource objects, how much data was returned for the query,
            whether more data is available with pagination, and fields (excluded,
            included).

        """
        if (params is None) == (criteria is None):
            raise ValueError(
                "Exactly one of either `params` or `criteria` must be specified."
            )

        if criteria is None:
            criteria = await self.ahandle_query_params(params)
            single_entry = isinstance(params, SingleEntryQueryParams)
        else:
            # Shallow copy, so popping "fields" below leaves the caller's dict intact.
            criteria = dict(criteria)
            # If criteria is given, whether a single entry was requested is an
            # unknown factor - better to then get a list of results.
            single_entry = False

        response_fields = criteria.pop("fields", self.all_fields)

        results, data_returned, more_data_available = await self._arun_db_query(
            criteria=criteria,
            single_entry=single_entry,
        )

        if single_entry:
            results = results[
                0] if results else None  # type: ignore[assignment]

            if data_returned > 1:
                raise NotFound(detail=(
                    f"Instead of a single entry, {data_returned} entries were found"
                ), )

        include_fields = (response_fields -
                          self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS)
        bad_optimade_fields = set()
        bad_provider_fields = set()
        for field in include_fields:
            if field not in self.resource_mapper.ALL_ATTRIBUTES:
                if field.startswith("_"):
                    # Only warn about fields carrying one of this provider's prefixes.
                    if any(
                            field.startswith(f"_{prefix}_") for prefix in
                            self.resource_mapper.SUPPORTED_PREFIXES):
                        bad_provider_fields.add(field)
                else:
                    bad_optimade_fields.add(field)

        if bad_provider_fields:
            warn(
                UnknownProviderProperty(detail=(
                    "Unrecognised field(s) for this provider requested in "
                    f"`response_fields`: {bad_provider_fields}.")))

        if bad_optimade_fields:
            raise BadRequest(detail=(
                "Unrecognised OPTIMADE field(s) in requested `response_fields`: "
                f"{bad_optimade_fields}."))

        if results:
            results = await self.resource_mapper.adeserialize(results)

        return (  # type: ignore[return-value]
            results,
            data_returned,
            more_data_available,
            self.all_fields - response_fields,
            include_fields,
        )