    async def create_one(self,
                         resource: "EntryResourceCreate") -> "EntryResource":
        """Create a new document in the MongoDB collection based on query parameters.

        Update the newly created document with an `"id"` field.
        The value will be the string representation of the `"_id"` field.
        This will only be done if `"id"` is not already present in `resource`.

        Parameters:
            resource: The resource to be created.

        Returns:
            The newly created document as a pydantic model entry resource.

        """
        resource.last_modified = datetime.utcnow()
        result = await self.collection.insert_one(await clean_python_types(
            resource.dict(exclude_unset=True)))
        LOGGER.debug(
            "Inserted resource %r in DB collection %s with ID %s",
            resource,
            self.collection.name,
            result.inserted_id,
        )

        if not resource.id:
            LOGGER.debug(
                "Updating resource with an `id` field equal to str(id_).")
            await self.collection.update_one(
                {"_id": result.inserted_id},
                {"$set": {
                    "id": str(result.inserted_id)
                }})

        return self.resource_cls(**self.resource_mapper.map_back(
            await self.collection.find_one({"_id": result.inserted_id})))
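# Illustrative sketch (not part of the source): the `_id` -> `id` backfill described
# in the docstring above, shown with bson's ObjectId directly. It only demonstrates
# the string conversion; the real work is done by `insert_one`/`update_one` on the
# motor collection.
from bson import ObjectId

inserted_id = ObjectId()         # what `insert_one(...)` returns as `inserted_id`
document = {"_id": inserted_id}  # a freshly inserted document without an `id` field

# `create_one` stores the string representation under `id` when it is missing:
document["id"] = str(inserted_id)

assert document["id"] == str(document["_id"])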
Example #2
async def ci_dev_startup() -> None:
    """Function to run at app startup - only relevant for CI or development to add test
    data."""
    if bool(os.getenv("CI", "")):
        LOGGER.info(
            "CI detected - Will load test gateways (after dropping the collection)!"
        )
    elif os.getenv("OPTIMADE_MONGO_DATABASE", "") == "optimade_gateway_dev":
        LOGGER.info(
            "Running in development mode - Will load test gateways (after dropping the "
            "collection)!")
    else:
        LOGGER.debug("Not in CI or development mode - will start normally.")
        return

    # Add test gateways
    import json
    from optimade_gateway.mongo.database import MONGO_DB
    from pathlib import Path

    test_data = (Path(__file__).parent.parent.joinpath(
        ".ci/test_gateways.json").resolve())

    await MONGO_DB[CONFIG.gateways_collection].drop()

    if await MONGO_DB[CONFIG.gateways_collection].count_documents({}) != 0:
        raise RuntimeError(
            f"Unexpectedly found documents in the {CONFIG.gateways_collection!r} Mongo"
            " collection after dropping it ! Found number of documents: "
            f"{await MONGO_DB[CONFIG.gateways_collection].count_documents({})}"
        )

    if not test_data.exists():
        raise FileNotFoundError(
            f"Could not find test data file with test gateways at {test_data} !"
        )

    with open(test_data, encoding="utf8") as handle:
        data = json.load(handle)
    await MONGO_DB[CONFIG.gateways_collection].insert_many(data)
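# Illustrative sketch (not part of the source): how a startup hook like
# `ci_dev_startup` (defined above) is typically registered on a FastAPI/Starlette
# application, and which environment variables activate the test-data path. The `app`
# object is an assumption; the real project may wire this up differently.
import os

from fastapi import FastAPI

app = FastAPI()

# Either of these makes `ci_dev_startup` drop and re-populate the gateways collection:
#   CI=1                                           (any non-empty value, as in CI runs)
#   OPTIMADE_MONGO_DATABASE=optimade_gateway_dev   (local development)
os.environ.setdefault("OPTIMADE_MONGO_DATABASE", "optimade_gateway_dev")

# Starlette/FastAPI runs registered "startup" handlers before serving requests.
app.add_event_handler("startup", ci_dev_startup)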
Example #3
async def load_optimade_providers_databases() -> None:  # pylint: disable=too-many-branches,too-many-statements,too-many-locals
    """Load in the providers' OPTIMADE databases from Materials-Consortia

    Utilize the Materials-Consortia list of OPTIMADE providers at
    [https://providers.optimade.org](https://providers.optimade.org).
    Load in all databases with a valid base URL.
    """
    import asyncio

    import httpx
    from optimade import __api_version__
    from optimade.models import LinksResponse
    from optimade.models.links import LinkType
    from optimade.server.routers.utils import BASE_URL_PREFIXES

    from optimade_gateway.common.utils import clean_python_types, get_resource_attribute
    from optimade_gateway.models.databases import DatabaseCreate
    from optimade_gateway.queries.perform import db_get_all_resources
    from optimade_gateway.routers.utils import resource_factory

    if not CONFIG.load_optimade_providers_databases:
        LOGGER.debug(
            "Will not load databases from Materials-Consortia list of providers."
        )
        return

    if TYPE_CHECKING or bool(os.getenv("MKDOCS_BUILD",
                                       "")):  # pragma: no cover
        providers: "Union[httpx.Response, LinksResponse]"

    async with httpx.AsyncClient() as client:
        providers = await client.get(
            f"https://providers.optimade.org/v{__api_version__.split('.', maxsplit=1)[0]}"
            "/links")

    if providers.is_error:
        LOGGER.warning(
            "Response from Materials-Consortia's list of OPTIMADE providers was not "
            "successful (status code != 200). No databases will therefore be added at "
            "server startup.")
        return

    LOGGER.info(
        "Registering Materials-Consortia list of OPTIMADE providers' databases."
    )

    providers = LinksResponse(**providers.json())

    valid_providers = []
    for provider in providers.data:
        if get_resource_attribute(provider, "id") in ("exmpl", "optimade"):
            LOGGER.info(
                "- %s (id=%r) - Skipping: Not a real provider.",
                get_resource_attribute(provider, "attributes.name", "N/A"),
                get_resource_attribute(provider, "id"),
            )
            continue

        if not get_resource_attribute(provider, "attributes.base_url"):
            LOGGER.info(
                "- %s (id=%r) - Skipping: No base URL information.",
                get_resource_attribute(provider, "attributes.name", "N/A"),
                get_resource_attribute(provider, "id"),
            )
            continue

        valid_providers.append(provider)

    # Query each provider's versioned /links endpoint (using the supported major
    # version base URL) to get a list of each provider's databases.
    # There is no need to use ThreadPoolExecutor here, since we want this to block
    # everything and then finish, before the server actually starts up.
    provider_queries = [
        asyncio.create_task(
            db_get_all_resources(
                database=provider,
                endpoint="links",
                response_model=LinksResponse,
            )) for provider in valid_providers
    ]

    for query in asyncio.as_completed(provider_queries):
        provider_databases, provider = await query

        LOGGER.info(
            "- %s (id=%r) - Processing",
            get_resource_attribute(provider, "attributes.name", "N/A"),
            get_resource_attribute(provider, "id"),
        )
        if not provider_databases:
            LOGGER.info("  - No OPTIMADE databases found.")
            continue

        provider_databases = [
            db for db in provider_databases if await clean_python_types(
                get_resource_attribute(db, "attributes.link_type", "")) ==
            LinkType.CHILD.value
        ]

        if not provider_databases:
            LOGGER.info("  - No OPTIMADE databases found.")
            continue

        for database in provider_databases:
            if not get_resource_attribute(database, "attributes.base_url"):
                LOGGER.info(
                    "  - %s (id=%r) - Skipping: No base URL information.",
                    get_resource_attribute(database, "attributes.name", "N/A"),
                    get_resource_attribute(database, "id"),
                )
                continue

            LOGGER.info(
                "  - %s (id=%r) - Checking versioned base URL and /structures",
                get_resource_attribute(database, "attributes.name", "N/A"),
                get_resource_attribute(database, "id"),
            )

            async with httpx.AsyncClient() as client:
                try:
                    db_response = await client.get(
                        f"{str(get_resource_attribute(database, 'attributes.base_url')).rstrip('/')}"  # pylint: disable=line-too-long
                        f"{BASE_URL_PREFIXES['major']}/structures", )
                except httpx.ReadTimeout:
                    LOGGER.info(
                        "  - %s (id=%r) - Skipping: Timeout while requesting "
                        "%s/structures.",
                        get_resource_attribute(database, "attributes.name",
                                               "N/A"),
                        get_resource_attribute(database, "id"),
                        BASE_URL_PREFIXES["major"],
                    )
                    continue
            if db_response.status_code != 200:
                LOGGER.info(
                    "  - %s (id=%r) - Skipping: Response from %s/structures is not "
                    "200 OK.",
                    get_resource_attribute(database, "attributes.name", "N/A"),
                    get_resource_attribute(database, "id"),
                    BASE_URL_PREFIXES["major"],
                )
                continue

            new_id = (f"{get_resource_attribute(provider, 'id')}"
                      f"/{get_resource_attribute(database, 'id')}" if
                      len(provider_databases) > 1 else get_resource_attribute(
                          database, "id"))
            registered_database, _ = await resource_factory(
                DatabaseCreate(
                    id=new_id,
                    **await clean_python_types(
                        get_resource_attribute(database, "attributes", {})),
                ))
            LOGGER.info(
                "  - %s (id=%r) - Registered database with id=%r",
                get_resource_attribute(database, "attributes.name", "N/A"),
                get_resource_attribute(database, "id"),
                registered_database.id,
            )
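# Illustrative sketch (not part of the source): `get_resource_attribute` (from
# optimade_gateway.common.utils) is used throughout the functions above and below for
# dotted-path access with a default. The helper here is a simplified stand-in to show
# the access pattern only; it is not the real implementation.
from typing import Any


def dotted_get(resource: Any, field: str, default: Any = None) -> Any:
    """Resolve 'attributes.base_url'-style paths on dicts or attribute-style objects."""
    value = resource
    for part in field.split("."):
        if isinstance(value, dict):
            value = value.get(part, default)
        else:
            value = getattr(value, part, default)
        if value is default:
            break
    return value


provider = {"id": "exmpl", "attributes": {"name": "Example provider", "base_url": None}}
print(dotted_get(provider, "attributes.name", "N/A"))  # -> "Example provider"
print(dotted_get(provider, "attributes.base_url"))     # -> None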
async def db_get_all_resources(
    database: "Union[LinksResource, Dict[str, Any]]",
    endpoint: str,
    response_model: "EntryResponseMany",
    query_params: str = "",
    raw_url: "Optional[str]" = None,
) -> "Tuple[List[Union[EntryResource, Dict[str, Any]]], Union[LinksResource, Dict[str, Any]]]":  # pylint: disable=line-too-long
    """Recursively retrieve all resources from an entry-listing endpoint

    This function keeps pulling the `links.next` link if `meta.more_data_available` is
    `True` to ultimately retrieve *all* entries for `endpoint`.

    !!! warning
        This function can be dangerous if an endpoint with hundreds or thousands of
        entries is requested.

    Parameters:
        database: The OPTIMADE implementation to be queried.
            It **must** have a valid base URL and id.
        endpoint: The entry-listing endpoint, e.g., `"structures"`.
        response_model: The expected OPTIMADE pydantic response model, e.g.,
            `optimade.models.StructureResponseMany`.
        query_params: URL query parameters to pass to the database.
        raw_url: A raw URL to use straight up instead of deriving a URL from `database`,
            `endpoint`, and `query_params`.

    Returns:
        A collected list of the successful responses' `data` values, together with the
        `database` object itself.

    """
    resulting_resources = []

    response, _ = db_find(
        database=database,
        endpoint=endpoint,
        response_model=response_model,
        query_params=query_params,
        raw_url=raw_url,
    )

    if isinstance(response, ErrorResponse):
        # An errored response will result in no databases from a provider.
        LOGGER.error(
            "Error while querying database (id=%r). Full response: %s",
            get_resource_attribute(database, "id"),
            response.json(indent=2),
        )
        return [], database

    resulting_resources.extend(response.data)

    if response.meta.more_data_available:
        next_page = get_resource_attribute(response, "links.next")
        if next_page is None:
            LOGGER.error(
                "Could not find a 'next' link for an OPTIMADE query request to %r "
                "(id=%r). Cannot get all resources from /%s, even though this was asked "
                "and `more_data_available` is `True` in the response.",
                get_resource_attribute(database, "attributes.name", "N/A"),
                get_resource_attribute(database, "id"),
                endpoint,
            )
            return resulting_resources, database

        more_resources, _ = await db_get_all_resources(
            database=database,
            endpoint=endpoint,
            response_model=response_model,
            query_params=query_params,
            raw_url=next_page,
        )
        resulting_resources.extend(more_resources)

    return resulting_resources, database
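# Illustrative usage sketch (not part of the source): fetch every /links resource from
# a single provider entry, following `links.next` pagination until exhausted. The
# provider dict and its base URL are placeholders.
import asyncio

from optimade.models import LinksResponse

from optimade_gateway.queries.perform import db_get_all_resources


async def main() -> None:
    provider = {
        "id": "exmpl",
        "attributes": {"base_url": "https://example.org/optimade"},
    }
    resources, queried_provider = await db_get_all_resources(
        database=provider,
        endpoint="links",
        response_model=LinksResponse,
    )
    print(f"Got {len(resources)} links resources from {queried_provider['id']}")


asyncio.run(main())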
"""Initialize the MongoDB database."""
from os import getenv
from typing import TYPE_CHECKING

from motor.motor_asyncio import AsyncIOMotorClient

from optimade_gateway.common.config import CONFIG
from optimade_gateway.common.logger import LOGGER

if TYPE_CHECKING or bool(getenv("MKDOCS_BUILD", "")):  # pragma: no cover
    # pylint: disable=unused-import
    from pymongo.database import Database
    from pymongo.mongo_client import MongoClient

MONGO_CLIENT: "MongoClient" = AsyncIOMotorClient(
    CONFIG.mongo_uri,
    appname="optimade-gateway",
    readConcernLevel="majority",
    readPreference="primary",
    w="majority",
)
"""The MongoDB motor client."""

MONGO_DB: "Database" = MONGO_CLIENT[CONFIG.mongo_database]
"""The MongoDB motor database.
This is a representation of the database used for the gateway service."""

LOGGER.info("Using: Real MongoDB (motor) at %s", CONFIG.mongo_uri)
LOGGER.info("Database: %s", CONFIG.mongo_database)
Example #6
async def post_search(request: Request,
                      search: Search) -> QueriesResponseSingle:
    """`POST /search`

    Coordinate a new OPTIMADE query in multiple databases through a gateway:

    1. Search for gateway in DB using `optimade_urls` and `database_ids`
    1. Create [`GatewayCreate`][optimade_gateway.models.gateways.GatewayCreate] model
    1. `POST` gateway resource to get ID - using functionality of `POST /gateways`
    1. Create new [Query][optimade_gateway.models.queries.QueryCreate] resource
    1. `POST` Query resource - using functionality of `POST /queries`
    1. Return `POST /queries` response -
        [`QueriesResponseSingle`][optimade_gateway.models.responses.QueriesResponseSingle]

    """
    databases_collection = await collection_factory(
        CONFIG.databases_collection)
    # NOTE: It may be that the final list of base URLs (`base_urls`) contains the same
    # provider(s), but with differing base URLs, if, for example, a versioned base URL
    # is supplied.
    base_urls = set()

    if search.database_ids:
        databases = await databases_collection.get_multiple(filter={
            "id": {
                "$in": await clean_python_types(search.database_ids)
            }
        })
        base_urls |= {
            get_resource_attribute(database, "attributes.base_url")
            for database in databases if get_resource_attribute(
                database, "attributes.base_url") is not None
        }

    if search.optimade_urls:
        base_urls |= {_ for _ in search.optimade_urls if _ is not None}

    if not base_urls:
        msg = "No (valid) OPTIMADE URLs with:"
        if search.database_ids:
            msg += (
                f"\n  Database IDs: {search.database_ids} and corresponding found URLs: "
                f"{[get_resource_attribute(database, 'attributes.base_url') for database in databases]}"
            )
        if search.optimade_urls:
            msg += f"\n  Passed OPTIMADE URLs: {search.optimade_urls}"
        raise BadRequest(detail=msg)

    # Ensure all URLs are `pydantic.AnyUrl`s
    if not all(isinstance(_, AnyUrl) for _ in base_urls):
        raise InternalServerError(
            "Unexpectedly failed to validate all base URLs as proper URLs.")

    databases = await databases_collection.get_multiple(
        filter={"base_url": {
            "$in": await clean_python_types(base_urls)
        }})
    if len(databases) == len(base_urls):
        # At this point it is expected that the list of databases in `databases`
        # is a complete set of the requested databases; the gateway is created below.
        pass
    elif len(databases) < len(base_urls):
        # There are unregistered databases
        current_base_urls = {
            get_resource_attribute(database, "attributes.base_url")
            for database in databases
        }
        databases.extend([
            LinksResource(
                id=(f"{url.user + '@' if url.user else ''}{url.host}"
                    f"{':' + url.port if url.port else ''}"
                    f"{url.path.rstrip('/') if url.path else ''}").replace(
                        ".", "__"),
                type="links",
                attributes=LinksResourceAttributes(
                    name=(f"{url.user + '@' if url.user else ''}{url.host}"
                          f"{':' + url.port if url.port else ''}"
                          f"{url.path.rstrip('/') if url.path else ''}"),
                    description="",
                    base_url=url,
                    link_type=LinkType.CHILD,
                    homepage=None,
                ),
            ) for url in base_urls - current_base_urls
        ])
    else:
        LOGGER.error(
            "Found more database entries in MongoDB than then number of passed base URLs."
            " This suggests ambiguity in the base URLs of databases stored in MongoDB.\n"
            "  base_urls: %s\n  databases %s",
            base_urls,
            databases,
        )
        raise InternalServerError(
            "Unambiguous base URLs. See logs for more details.")

    gateway = GatewayCreate(databases=databases)
    gateway, created = await resource_factory(gateway)

    if created:
        LOGGER.debug("A new gateway was created for a query (id=%r)",
                     gateway.id)
    else:
        LOGGER.debug("A gateway was found and reused for a query (id=%r)",
                     gateway.id)

    query = QueryCreate(
        endpoint=search.endpoint,
        gateway_id=gateway.id,
        query_parameters=search.query_parameters,
    )
    query, created = await resource_factory(query)

    if created:
        asyncio.create_task(perform_query(url=request.url, query=query))

    collection = await collection_factory(CONFIG.queries_collection)

    return QueriesResponseSingle(
        links=ToplevelLinks(next=None),
        data=query,
        meta=meta_values(
            url=request.url,
            data_returned=1,
            data_available=await collection.acount(),
            more_data_available=False,
            **{f"_{CONFIG.provider.prefix}_created": created},
        ),
    )
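# Illustrative sketch (not part of the source): calling `POST /search` over HTTP. The
# host/port, database ID, and filter are placeholders; the payload mirrors the `Search`
# model fields used above (`optimade_urls`, `database_ids`, `endpoint`,
# `query_parameters`).
import httpx

payload = {
    "optimade_urls": ["https://example.org/optimade"],
    "database_ids": ["example_db"],
    "endpoint": "structures",
    "query_parameters": {"filter": 'elements HAS "Si"', "page_limit": 10},
}

response = httpx.post("http://localhost:5000/search", json=payload)
response.raise_for_status()

# The response is a QueriesResponseSingle; `data` holds the created (or reused) query.
print("Query resource id:", response.json()["data"]["id"])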
Example #7
async def get_search(
    request: Request,
    response: Response,
    search_params: SearchQueryParams = Depends(),
    entry_params: EntryListingQueryParams = Depends(),
) -> Union[QueriesResponseSingle, EntryResponseMany, ErrorResponse,
           RedirectResponse]:
    """`GET /search`

    Coordinate a new OPTIMADE query in multiple databases through a gateway:

    1. Create a [`Search`][optimade_gateway.models.search.Search] `POST` data - calling
        `POST /search`.
    1. Wait [`search_params.timeout`][optimade_gateway.queries.params.SearchQueryParams]
        seconds before returning the query, if it has not finished before.
    1. Return query - similar to `GET /queries/{query_id}`.

    This endpoint works similarly to `GET /queries/{query_id}`, where one passes the query
    parameters directly in the URL, instead of first POSTing a query and then going to its
    URL. Hence, a
    [`QueriesResponseSingle`][optimade_gateway.models.responses.QueriesResponseSingle] is
    the standard response model for this endpoint.

    If the timeout time is reached and the query has not yet finished, the user is
    redirected to the specific URL for the query.

    If the `as_optimade` query parameter is `True`, the response will be parseable as a
    standard OPTIMADE entry listing endpoint like, e.g., `/structures`.
    For more information see the
    [OPTIMADE specification](https://github.com/Materials-Consortia/OPTIMADE/blob/master/optimade.rst#entry-listing-endpoints).

    """
    try:
        search = Search(
            query_parameters=OptimadeQueryParameters(
                **{
                    field: getattr(entry_params, field)
                    for field in OptimadeQueryParameters.__fields__
                    if getattr(entry_params, field)
                }),
            optimade_urls=search_params.optimade_urls,
            endpoint=search_params.endpoint,
            database_ids=search_params.database_ids,
        )
    except ValidationError as exc:
        raise BadRequest(detail=(
            "A Search object could not be created from the given URL query "
            f"parameters. Error(s): {exc.errors()}")) from exc

    queries_response = await post_search(request, search=search)

    if not queries_response.data:
        LOGGER.error("QueryResource not found in POST /search response:\n%s",
                     queries_response)
        raise RuntimeError(
            "Expected the response from POST /search to return a QueryResource, but it "
            "did not.")

    once = True
    start_time = time()
    while (  # pylint: disable=too-many-nested-blocks
            time() < (start_time + search_params.timeout) or once):
        # Make sure to run this at least once (e.g., if timeout=0)
        once = False

        collection = await collection_factory(CONFIG.queries_collection)

        query: QueryResource = await collection.get_one(
            **{"filter": {
                "id": queries_response.data.id
            }})

        if query.attributes.state == QueryState.FINISHED:
            if query.attributes.response and query.attributes.response.errors:
                for error in query.attributes.response.errors:
                    if error.status:
                        for part in error.status.split(" "):
                            try:
                                response.status_code = int(part)
                                break
                            except ValueError:
                                pass
                        if response.status_code and response.status_code >= 300:
                            break
                else:
                    response.status_code = 500

            if search_params.as_optimade:
                return await query.response_as_optimade(url=request.url)

            return QueriesResponseSingle(
                links=ToplevelLinks(next=None),
                data=query,
                meta=meta_values(
                    url=request.url,
                    data_returned=1,
                    data_available=await collection.acount(),
                    more_data_available=False,
                ),
            )

        await asyncio.sleep(0.1)

    # The query has not yet succeeded and we're past the timeout time -> Redirect to
    # /queries/<id>
    return RedirectResponse(query.links.self)
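# Illustrative sketch (not part of the source): the GET variant expresses the same
# search as URL query parameters, combining `SearchQueryParams` (timeout, as_optimade,
# optimade_urls, database_ids, endpoint) with standard OPTIMADE entry-listing
# parameters such as `filter` and `page_limit`. Host/port are placeholders.
import httpx

params = {
    "database_ids": "example_db",
    "endpoint": "structures",
    "filter": 'elements HAS "Si"',
    "page_limit": 10,
    "timeout": 15,        # seconds to wait before redirecting to /queries/{query_id}
    "as_optimade": True,  # return a standard OPTIMADE entry-listing response
}

response = httpx.get(
    "http://localhost:5000/search", params=params, follow_redirects=True
)
response.raise_for_status()
print(response.json()["meta"])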
Example #8
async def resource_factory(  # pylint: disable=too-many-branches
    create_resource: "Union[DatabaseCreate, GatewayCreate, QueryCreate]",
) -> "Tuple[Union[LinksResource, GatewayResource, QueryResource], bool]":
    """Get or create a resource

    Currently supported resources:

    - `"databases"` ([`DatabaseCreate`][optimade_gateway.models.databases.DatabaseCreate]
        ->
        [`LinksResource`](https://www.optimade.org/optimade-python-tools/api_reference/models/links/#optimade.models.links.LinksResource))
    - `"gateways"` ([`GatewayCreate`][optimade_gateway.models.gateways.GatewayCreate] ->
        [`GatewayResource`][optimade_gateway.models.gateways.GatewayResource])
    - `"queries"` ([`QueryCreate`][optimade_gateway.models.queries.QueryCreate] ->
        [`QueryResource`][optimade_gateway.models.queries.QueryResource])

    For each of the resources, "uniqueness" is determined in the following way:

    === "Databases"
        The `base_url` field is considered unique across all databases.

        If a `base_url` is provided via a
        [`Link`](https://www.optimade.org/optimade-python-tools/api_reference/models/jsonapi/#optimade.models.jsonapi.Link)
        model, the `base_url.href` value is used to query the MongoDB.

    === "Gateways"
        The collected list of `databases.attributes.base_url` values is considered unique
        across all gateways.

        In the database, the search is done as a combination of the length/size of the
        `databases`' Python list/MongoDB array and a match on all (using the MongoDB
        `$all` operator) of the
        [`databases.attributes.base_url`](https://www.optimade.org/optimade-python-tools/api_reference/models/links/#optimade.models.links.LinksResourceAttributes.base_url)
        element values, when compared with the `create_resource`.

        !!! important
            The `database_ids` attribute **must not** contain values that are not also
            included in the `databases` attribute, in the form of the IDs for the
            individual databases. If this should be the case an
            [`OptimadeGatewayError`][optimade_gateway.common.exceptions.OptimadeGatewayError]
            will be thrown.

    === "Queries"
        The `gateway_id`, `query_parameters`, and `endpoint` fields are collectively
        considered to define uniqueness for a
        [`QueryResource`][optimade_gateway.models.queries.QueryResource] in the MongoDB
        collection.

        !!! attention
            Only the `/structures` entry endpoint can be queried with multiple expected
            responses.

            This means the `endpoint` field defaults to `"structures"`, i.e., the
            [`StructureResource`](https://www.optimade.org/optimade-python-tools/all_models/#optimade.models.structures.StructureResource)
            resource model.

    Parameters:
        create_resource: The resource to be retrieved or created anew.

    Returns:
        Two things in a tuple:

        - Either a [`GatewayResource`][optimade_gateway.models.gateways.GatewayResource];
            a [`QueryResource`][optimade_gateway.models.queries.QueryResource]; or a
            [`LinksResource`](https://www.optimade.org/optimade-python-tools/api_reference/models/links/#optimade.models.links.LinksResource)
            and
        - whether or not the resource was newly created.

    """
    created = False

    if isinstance(create_resource, DatabaseCreate):
        collection_name = CONFIG.databases_collection

        base_url = get_resource_attribute(create_resource, "base_url")

        mongo_query = {
            "$or": [
                {"base_url": {"$eq": base_url}},
                {"base_url.href": {"$eq": base_url}},
            ]
        }
    elif isinstance(create_resource, GatewayCreate):
        collection_name = CONFIG.gateways_collection

        # One MUST have taken care of database_ids prior to calling `resource_factory()`
        database_attr_ids = {_.id for _ in create_resource.databases or []}
        unknown_ids = {
            database_id
            for database_id in create_resource.database_ids or []
            if database_id not in database_attr_ids
        }
        if unknown_ids:
            raise OptimadeGatewayError(
                "When using `resource_factory()` for `GatewayCreate`, `database_ids` MUST"
                f" not include unknown IDs. Passed unknown IDs: {unknown_ids}"
            )

        mongo_query = {
            "databases": {"$size": len(create_resource.databases)},
            "databases.attributes.base_url": {
                "$all": [_.attributes.base_url for _ in create_resource.databases or []]
            },
        }
    elif isinstance(create_resource, QueryCreate):
        collection_name = CONFIG.queries_collection

        # Currently only /structures entry endpoints can be queried with multiple
        # expected responses.
        create_resource.endpoint = (
            create_resource.endpoint if create_resource.endpoint else "structures"
        )

        mongo_query = {
            "gateway_id": {"$eq": create_resource.gateway_id},
            "query_parameters": {"$eq": create_resource.query_parameters},
            "endpoint": {"$eq": create_resource.endpoint},
        }
    else:
        raise TypeError(
            "create_resource must be either a DatabaseCreate, GatewayCreate, or "
            f"QueryCreate object, not {type(create_resource)!r}"
        )

    collection = await collection_factory(collection_name)
    result, data_returned, more_data_available, _, _ = await collection.afind(
        criteria={"filter": await clean_python_types(mongo_query)}
    )

    if more_data_available:
        raise OptimadeGatewayError(
            "more_data_available MUST be False for a single entry response, however it "
            f"is {more_data_available}"
        )

    if result:
        if data_returned > 1:
            raise OptimadeGatewayError(
                f"More than one {result[0].type} were found. IDs of found "
                f"{result[0].type}: {[_.id for _ in result]}"
            )
        if isinstance(result, list):
            result = result[0]
    else:
        if isinstance(create_resource, DatabaseCreate):
            # Set required `LinksResourceAttributes` values if not set
            if not create_resource.description:
                create_resource.description = (
                    f"{create_resource.name} created by OPTIMADE gateway database "
                    "registration."
                )
            if not create_resource.link_type:
                create_resource.link_type = LinkType.EXTERNAL
            if not create_resource.homepage:
                create_resource.homepage = None
        elif isinstance(create_resource, GatewayCreate):
            # Do not store `database_ids`
            if "database_ids" in create_resource.__fields_set__:
                create_resource.database_ids = None
                create_resource.__fields_set__.remove("database_ids")
        elif isinstance(create_resource, QueryCreate):
            create_resource.state = QueryState.CREATED
        result = await collection.create_one(create_resource)
        LOGGER.debug("Created new %s: %r", result.type, result)
        created = True

    return result, created
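# Illustrative usage sketch (not part of the source): the same QueryCreate payload
# yields `created=True` on first use and `created=False` when an identical query
# already exists. The gateway id and query parameters are placeholders.
from optimade_gateway.models.queries import QueryCreate
from optimade_gateway.routers.utils import resource_factory


async def get_or_create_query() -> None:
    create_query = QueryCreate(
        gateway_id="example_gateway",
        query_parameters={"filter": 'elements HAS "Si"', "page_limit": 10},
        # `endpoint` is left unset; resource_factory() defaults it to "structures".
    )

    query, created = await resource_factory(create_query)
    print(query.id, "(newly created)" if created else "(reused)")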
async def process_db_response(
    response: "Union[ErrorResponse, EntryResponseMany, EntryResponseOne]",
    database_id: str,
    query: "QueryResource",
    gateway: "GatewayResource",
) -> "Union[List[EntryResource], List[Dict[str, Any]], EntryResource, Dict[str, Any], None]":  # pylint: disable=line-too-long
    """Process an OPTIMADE database response.

    The passed `query` will be updated with the top-level `meta` information:
    `data_available`, `data_returned`, and `more_data_available`.

    Since only either `data` or `errors` should ever be present, one or the other will
    be an empty list or `None`.

    Parameters:
        response: The OPTIMADE database response to be processed.
        database_id: The database's `id` under which the returned resources or errors
            will be delivered.
        query: A resource representing the performed query.
        gateway: A resource representing the gateway that was queried.

    Returns:
        The response's `data`.

    """
    results = []
    errors = []

    LOGGER.debug("Starting to process database_id: %s", database_id)

    if isinstance(response, ErrorResponse):
        for error in response.errors:
            if isinstance(error.id,
                          str) and error.id.startswith("OPTIMADE_GATEWAY"):
                warn(error.detail, OptimadeGatewayWarning)
            else:
                # The model `ErrorResponse` does not allow the objects in the top-level
                # `errors` list to be parsed as dictionaries - they must be a pydantic
                # model.
                meta_error = {}
                if error.meta:
                    meta_error = error.meta.dict()
                meta_error.update({
                    f"_{CONFIG.provider.prefix}_source_gateway": {
                        "id": gateway.id,
                        "type": gateway.type,
                        "links": {
                            "self": gateway.links.self
                        },
                    },
                    f"_{CONFIG.provider.prefix}_source_database": {
                        "id": database_id,
                        "type": "links",
                        "links": {
                            "self":
                            (str(gateway.links.self).split("gateways",
                                                           maxsplit=1)[0] +
                             f"databases/{database_id}")
                        },
                    },
                })
                error.meta = Meta(**meta_error)
                errors.append(error)
        data_returned = 0
        more_data_available = False
    else:
        results = response.data

        if isinstance(results, list):
            data_returned = response.meta.data_returned or len(results)
        else:
            data_returned = response.meta.data_returned or (0 if not results
                                                            else 1)

        more_data_available = response.meta.more_data_available or False

    data_available = response.meta.data_available or 0

    extra_updates = {
        "$inc": {
            "response.meta.data_available": data_available,
            "response.meta.data_returned": data_returned,
        }
    }
    if not get_resource_attribute(
            query,
            "attributes.response.meta.more_data_available",
            False,
            disambiguate=False,  # Extremely minor speed-up
    ):
        # Keep it True once it has been set to True.
        extra_updates.update({
            "$set": {
                "response.meta.more_data_available": more_data_available
            }
        })

    # This ensures an empty list under `response.data.{database_id}` is returned if the
    # case is simply that there are no results to return.
    if errors:
        extra_updates.update(
            {"$addToSet": {
                "response.errors": {
                    "$each": errors
                }
            }})
    await update_query(
        query,
        f"response.data.{database_id}",
        results,
        operator=None,
        **extra_updates,
    )

    return results
Example #10
async def update_query(  # pylint: disable=too-many-branches
    query: "QueryResource",
    field: str,
    value: "Any",
    operator: "Optional[str]" = None,
    **mongo_kwargs: "Any",
) -> None:
    """Update a query's `field` attribute with `value`.

    If `field` is a dot-separated value, only the last part may refer to a field that
    does not already exist. Otherwise a `KeyError` or `AttributeError` will be raised.

    !!! note
        This can *only* update a field for a query's `attributes`, i.e., this function
        cannot update `id`, `type` or any other top-level resource field.

    !!! important
        `mongo_kwargs` will not be considered for updating the pydantic model instance.

    Parameters:
        query: The query to be updated.
        field: The `attributes` field (key) to be set.
            This can be a dot-separated key value to signify embedded fields.

            **Example**: `response.meta`.
        value: The (possibly) new value for `field`.
        operator: A MongoDB operator to be used for updating `field` with `value`.
        **mongo_kwargs: Further MongoDB update filters.

    """
    operator = operator or "$set"

    if operator and not operator.startswith("$"):
        operator = f"${operator}"

    update_time = datetime.utcnow()

    update_kwargs = {"$set": {"last_modified": update_time}}

    if mongo_kwargs:
        update_kwargs.update(mongo_kwargs)

    if operator and operator == "$set":
        update_kwargs["$set"].update({field: value})
    elif operator:
        if operator in update_kwargs:
            update_kwargs[operator].update({field: value})
        else:
            update_kwargs.update({operator: {field: value}})

    # MongoDB
    collection = await collection_factory(CONFIG.queries_collection)
    result: "UpdateResult" = await collection.collection.update_one(
        filter={"id": {
            "$eq": query.id
        }},
        update=await clean_python_types(update_kwargs),
    )
    if result.matched_count != 1:
        LOGGER.error(
            ("matched_count should have been exactly 1, it was: %s. "
             "Returned update_one result: %s"),
            result.matched_count,
            result.raw_result,
        )

    # Pydantic model instance
    query.attributes.last_modified = update_time
    if "." in field:
        field_list = field.split(".")
        field = getattr(query.attributes, field_list[0])
        for field_part in field_list[1:-1]:
            if isinstance(field, dict):
                field = field.get(field_part)
            else:
                field = getattr(field, field_part)
        if isinstance(field, dict):
            field[field_list[-1]] = value
        else:
            setattr(field, field_list[-1], value)
    else:
        setattr(query.attributes, field, value)
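# Illustrative usage sketch (not part of the source): update a nested attribute on an
# existing QueryResource with `update_query` (defined above). With the default "$set"
# operator both the MongoDB document and the in-memory pydantic model are updated;
# dot-separated fields address embedded values. Assumes the query already has a
# `response.meta` attribute.
async def record_data_returned(query: "QueryResource") -> None:
    await update_query(query, "response.meta.data_returned", 42)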