예제 #1
0
def get_entries(
    collection: EntryCollection,
    response: EntryResponseMany,
    request: Request,
    params: EntryListingQueryParams,
) -> EntryResponseMany:
    """Generalized /{entry} endpoint getter.

    Run ``collection.find(params)``, resolve the relationships named in the
    comma-separated ``params.include`` (if present and truthy), derive a
    ``next`` pagination link when more data is available, and trim entries
    to the requested fields before building the response.

    Parameters:
        collection: Entry collection to query.
        response: Response model class instantiated for the return value.
        request: Incoming request; its URL seeds the ``next`` link and ``meta``.
        params: Parsed entry-listing query parameters.

    Returns:
        An instance of ``response``.

    """
    from optimade.server.routers import ENTRY_COLLECTIONS

    results, data_returned, more_data_available, fields, include_fields = (
        collection.find(params)
    )

    requested_includes = (
        params.include.split(",") if getattr(params, "include", False) else []
    )
    included = get_included_relationships(
        results, ENTRY_COLLECTIONS, requested_includes
    )

    next_link = None
    if more_data_available:
        # Reuse the current query string, advancing `page_offset` past the
        # entries returned on this page.
        query_string = urllib.parse.parse_qs(request.url.query)
        current_offset = int(query_string.get("page_offset", [0])[0])
        query_string["page_offset"] = current_offset + len(results)
        encoded_query = urllib.parse.urlencode(query_string, doseq=True)
        next_link = f"{get_base_url(request.url)}{request.url.path}?{encoded_query}"
    links = ToplevelLinks(next=next_link)

    if fields or include_fields:
        results = handle_response_fields(results, fields, include_fields)

    meta = meta_values(
        url=request.url,
        data_returned=data_returned,
        data_available=len(collection),
        more_data_available=more_data_available,
    )
    return response(links=links, data=results, meta=meta, included=included)
예제 #2
0
async def post_queries(
    request: Request,
    query: QueryCreate,
) -> QueriesResponseSingle:
    """`POST /queries`

    Create or return existing gateway query according to `query`.

    The gateway referenced by `query.gateway_id` is validated first. When a
    new query resource is created, `perform_query` is scheduled with
    `asyncio.create_task` and this endpoint returns without awaiting it.
    """
    gateways = await collection_factory(CONFIG.gateways_collection)
    await validate_resource(gateways, query.gateway_id)

    result, created = await resource_factory(query)
    if created:
        # NOTE(review): the task object is not stored; asyncio only keeps weak
        # references to tasks, so consider holding a strong reference.
        asyncio.create_task(perform_query(url=request.url, query=result))

    queries = await collection_factory(CONFIG.queries_collection)
    n_queries = await queries.acount()
    created_flag = {f"_{CONFIG.provider.prefix}_created": created}

    return QueriesResponseSingle(
        links=ToplevelLinks(next=None),
        data=result,
        meta=meta_values(
            url=request.url,
            data_returned=1,
            data_available=n_queries,
            more_data_available=False,
            **created_flag,
        ),
    )
예제 #3
0
def get_entries(
    backend: orm.implementation.Backend,
    collection: AiidaCollection,
    response: EntryResponseMany,
    request: Request,
    params: EntryListingQueryParams,
) -> EntryResponseMany:
    """Generalized /{entry} endpoint getter.

    Query ``collection`` on ``backend``, build pagination links with
    ``handle_pagination`` and trim entries to the requested fields.

    Parameters:
        backend: AiiDA backend passed through to ``collection.find``.
        collection: Collection to query.
        response: Response model class instantiated for the return value.
        request: Incoming request; used for pagination and ``meta``.
        params: Parsed entry-listing query parameters.

    Returns:
        An instance of ``response``.
    """
    found = collection.find(backend, params)
    results, data_returned, more_data_available, data_available, fields = found

    pagination_links = handle_pagination(
        request=request,
        more_data_available=more_data_available,
        nresults=len(results),
    )

    if fields:
        results = handle_response_fields(results, fields)

    links = ToplevelLinks(**pagination_links)
    meta = meta_values(
        str(request.url), data_returned, data_available, more_data_available
    )
    return response(links=links, data=results, meta=meta)
예제 #4
0
def get_single_entry(  # pylint: disable=too-many-arguments
    backend: orm.implementation.Backend,
    collection: AiidaCollection,
    entry_id: str,
    response: EntryResponseOne,
    request: Request,
    params: SingleEntryQueryParams,
) -> EntryResponseOne:
    """Generalized /{entry}/{entry_id} endpoint getter.

    Overwrites ``params.filter`` (in place) with ``id=<entry_id>`` and
    queries ``collection`` on ``backend``.

    Raises:
        StarletteHTTPException: Status 500 if the collection reports
            ``more_data_available`` for what should be a single entry.
    """
    params.filter = f"id={entry_id}"
    found = collection.find(backend, params)
    results, data_returned, more_data_available, data_available, fields = found

    if more_data_available:
        message = (
            "more_data_available MUST be False for single entry response, "
            f"however it is {more_data_available}"
        )
        raise StarletteHTTPException(status_code=500, detail=message)

    links = ToplevelLinks(next=None)

    if fields and results is not None:
        # Trim to the requested fields; the single entry is the first item
        # of the returned sequence.
        results = handle_response_fields(results, fields)[0]

    meta = meta_values(
        str(request.url), data_returned, data_available, more_data_available
    )
    return response(links=links, data=results, meta=meta)
예제 #5
0
def get_entries(
    collection: EntryCollection,
    response: EntryResponseMany,
    request: Request,
    params: EntryListingQueryParams,
) -> EntryResponseMany:
    """Generalized /{entry} endpoint getter.

    Query ``collection``, resolve included relationships, build a ``next``
    link (same scheme, host and path as the request, with ``page_offset``
    advanced) when more data is available, and trim entries to the
    requested fields.

    Returns:
        An instance of ``response``.
    """
    from optimade.server.routers import ENTRY_COLLECTIONS

    results, data_returned, more_data_available, fields = collection.find(params)
    included = get_included_relationships(results, ENTRY_COLLECTIONS)

    next_url = None
    if more_data_available:
        current = urllib.parse.urlparse(str(request.url))
        query_string = urllib.parse.parse_qs(current.query)
        offset = int(query_string.get("page_offset", [0])[0])
        query_string["page_offset"] = offset + len(results)
        encoded_query = urllib.parse.urlencode(query_string, doseq=True)
        next_url = (
            f"{current.scheme}://{current.netloc}{current.path}?{encoded_query}"
        )
    links = ToplevelLinks(next=next_url)

    if fields:
        results = handle_response_fields(results, fields)

    meta = meta_values(
        url=str(request.url),
        data_returned=data_returned,
        data_available=len(collection),
        more_data_available=more_data_available,
    )
    return response(links=links, data=results, meta=meta, included=included)
예제 #6
0
async def get_entries(
    collection: AsyncMongoCollection,
    response_cls: "EntryResponseMany",
    request: "Request",
    params: "EntryListingQueryParams",
) -> "EntryResponseMany":
    """Generalized `/{entries}` endpoint getter.

    Query `collection` asynchronously, build a `next` link when more data is
    available, trim entries to the requested fields and report the
    collection count (`acount()`) as `data_available`.

    Returns:
        An instance of `response_cls`.
    """
    find_result = await collection.afind(params=params)
    results, data_returned, more_data_available, fields, include_fields = find_result

    next_url = None
    if more_data_available:
        # Deduce the `next` link from the current request by bumping
        # `page_offset` by the number of entries on this page.
        query_string = urllib.parse.parse_qs(request.url.query)
        query_string["page_offset"] = [int(query_string.get("page_offset", [0])[0]) + len(results)]  # type: ignore[list-item, arg-type]
        encoded_query = urllib.parse.urlencode(query_string, doseq=True)
        base_url = get_base_url(request.url)
        next_url = f"{base_url}{request.url.path}?{encoded_query}"
    links = ToplevelLinks(next=next_url)

    if fields or include_fields:
        results = handle_response_fields(results, fields, include_fields)

    data_available = await collection.acount()
    return response_cls(
        links=links,
        data=results,
        meta=meta_values(
            url=request.url,
            data_returned=data_returned,
            data_available=data_available,
            more_data_available=more_data_available,
        ),
    )
예제 #7
0
def get_single_entry(
    collection: EntryCollection,
    entry_id: str,
    response: EntryResponseOne,
    request: Request,
    params: SingleEntryQueryParams,
) -> EntryResponseOne:
    """Generalized /{entry}/{entry_id} endpoint getter.

    Overwrites ``params.filter`` (in place) to select ``entry_id``, queries
    ``collection``, resolves the relationships named in the comma-separated
    ``params.include`` (if present and truthy) and trims the entry to the
    requested fields.

    Parameters:
        collection: Entry collection to query.
        entry_id: ID of the requested entry.
        response: Response model class instantiated for the return value.
        request: Incoming request; its URL is used for ``meta``.
        params: Parsed single-entry query parameters.

    Returns:
        An instance of ``response``.

    Raises:
        HTTPException: Status 500 if the collection reports
            ``more_data_available`` for what should be a single entry.
    """
    from optimade.server.routers import ENTRY_COLLECTIONS

    # NOTE(review): `entry_id` is interpolated unescaped; a `"` in it alters
    # the filter expression.
    params.filter = f'id="{entry_id}"'
    (
        results,
        data_returned,
        more_data_available,
        fields,
        include_fields,
    ) = collection.find(params)

    include = []
    if getattr(params, "include", False):
        include.extend(params.include.split(","))
    included = get_included_relationships(results, ENTRY_COLLECTIONS, include)

    if more_data_available:
        raise HTTPException(
            status_code=500,
            detail=(
                "more_data_available MUST be False for single entry response, "
                f"however it is {more_data_available}"
            ),
        )

    links = ToplevelLinks(next=None)

    # Parenthesised on purpose: `and` binds tighter than `or`, so without the
    # parentheses a request with `fields` set but no matching entry
    # (`results is None`) would pass `None` to `handle_response_fields`.
    if (fields or include_fields) and results is not None:
        # The single entry is the first item of the returned sequence.
        results = handle_response_fields(results, fields, include_fields)[0]

    return response(
        links=links,
        data=results,
        meta=meta_values(
            url=request.url,
            data_returned=data_returned,
            data_available=len(collection),
            more_data_available=more_data_available,
        ),
        included=included,
    )
예제 #8
0
async def get_gateway(request: Request,
                      gateway_id: str) -> GatewaysResponseSingle:
    """`GET /gateways/{gateway ID}`

    Return a single [`GatewayResource`][optimade_gateway.models.gateways.GatewayResource].

    The gateway is fetched with `get_valid_resource` from the gateways
    collection; `meta.data_available` is that collection's `acount()`.
    """
    gateways = await collection_factory(CONFIG.gateways_collection)
    gateway = await get_valid_resource(gateways, gateway_id)
    n_gateways = await gateways.acount()

    meta = meta_values(
        url=request.url,
        data_returned=1,
        data_available=n_gateways,
        more_data_available=False,
    )
    return GatewaysResponseSingle(
        links=ToplevelLinks(next=None), data=gateway, meta=meta
    )
예제 #9
0
async def post_gateways(request: Request,
                        gateway: GatewayCreate) -> GatewaysResponseSingle:
    """`POST /gateways`

    Create or return existing gateway according to `gateway`.

    Databases referenced by `gateway.database_ids` are fetched from the
    databases collection and appended to `gateway.databases` (mutated in
    place), skipping those whose `id` is already present there.
    """
    if gateway.database_ids:
        databases_collection = await collection_factory(
            CONFIG.databases_collection)
        id_filter = {
            "id": {"$in": await clean_python_types(gateway.database_ids)}
        }
        found = await databases_collection.get_multiple(filter=id_filter)

        if not isinstance(gateway.databases, list):
            gateway.databases = []

        known_ids = [database.id for database in gateway.databases]
        gateway.databases.extend(
            database for database in found if database.id not in known_ids)

    result, created = await resource_factory(gateway)
    gateways_collection = await collection_factory(CONFIG.gateways_collection)
    n_gateways = await gateways_collection.acount()

    return GatewaysResponseSingle(
        links=ToplevelLinks(next=None),
        data=result,
        meta=meta_values(
            url=request.url,
            data_returned=1,
            data_available=n_gateways,
            more_data_available=False,
            **{f"_{CONFIG.provider.prefix}_created": created},
        ),
    )
예제 #10
0
async def get_query(
    request: Request,
    query_id: str,
    response: Response,
) -> QueriesResponseSingle:
    """`GET /queries/{query_id}`

    Return a single [`QueryResource`][optimade_gateway.models.queries.QueryResource].

    If the stored query response has errors, each error's `status` is scanned
    for its first space-separated integer token; the first such code >= 300
    becomes the HTTP status of `response`. If no error yields one, the
    status is set to 500.
    """
    queries = await collection_factory(CONFIG.queries_collection)
    query: QueryResource = await get_valid_resource(queries, query_id)

    stored_response = query.attributes.response
    if stored_response and stored_response.errors:
        error_code_found = False
        for error in stored_response.errors:
            if not error.status:
                continue
            # Take the first space-separated token that parses as an int.
            for token in error.status.split(" "):
                try:
                    response.status_code = int(token)
                except ValueError:
                    continue
                break
            if response.status_code and response.status_code >= 300:
                error_code_found = True
                break
        if not error_code_found:
            response.status_code = 500

    n_queries = await queries.acount()
    return QueriesResponseSingle(
        links=ToplevelLinks(next=None),
        data=query,
        meta=meta_values(
            url=request.url,
            data_returned=1,
            data_available=n_queries,
            more_data_available=False,
        ),
    )
예제 #11
0
def get_single_entry(
    collection: AiidaCollection,
    entry_id: str,
    response: EntryResponseOne,
    request: Request,
    params: SingleEntryQueryParams,
) -> EntryResponseOne:
    """Generalized /{entry}/{entry_id} endpoint getter.

    Overwrites ``params.filter`` (in place) with ``id=<entry_id>`` and
    queries ``collection``.

    Raises:
        HTTPException: Status 500 if the collection reports
            ``more_data_available`` for what should be a single entry.
    """
    params.filter = f"id={entry_id}"
    found = collection.find(params)
    results, data_returned, more_data_available, data_available, fields = found

    if more_data_available:
        raise HTTPException(
            status_code=500,
            detail=(
                "more_data_available MUST be False for single entry response, "
                f"however it is {more_data_available}"
            ),
        )

    links = ToplevelLinks(next=None)

    if fields and results is not None:
        # Trim to the requested fields; the single entry is the first item
        # of the returned sequence.
        results = handle_response_fields(results, fields, collection)[0]

    meta = meta_values(
        str(request.url), data_returned, data_available, more_data_available
    )
    return response(links=links, data=results, meta=meta)
async def perform_query(
    url: "URL",
    query: "QueryResource",
) -> "Union[EntryResponseMany, ErrorResponse, GatewayQueryResponse]":
    """Perform OPTIMADE query with gateway.

    The query's state is updated STARTED -> IN_PROGRESS -> FINISHED. Every
    database of the query's gateway is queried with `db_find` in a worker
    thread, and each database response is merged into the stored query via
    `process_db_response` (in submission order).

    Parameters:
        url: Original request URL. Its path is extended with the query ID and
            used as the `meta` URL of the initial, empty stored response.
        query: The query to be performed.

    Returns:
        This function returns the final response; a
        [`GatewayQueryResponse`][optimade_gateway.models.queries.GatewayQueryResponse].

    """
    await update_query(query, "state", QueryState.STARTED)

    gateway: GatewayResource = await get_valid_resource(
        await collection_factory(CONFIG.gateways_collection),
        query.attributes.gateway_id,
    )

    filter_queries = await prepare_query_filter(
        database_ids=[_.id for _ in gateway.attributes.databases],
        filter_query=query.attributes.query_parameters.filter,
    )

    url = url.replace(path=f"{url.path.rstrip('/')}/{query.id}")
    await update_query(
        query,
        "response",
        GatewayQueryResponse(
            data={},
            links=ToplevelLinks(next=None),
            meta=meta_values(
                url=url,
                data_available=0,
                data_returned=0,
                more_data_available=False,
            ),
        ),
        operator=None,
        **{"$set": {
            "state": QueryState.IN_PROGRESS
        }},
    )

    # `ThreadPoolExecutor` raises ValueError for `max_workers <= 0`, which a
    # gateway without databases would otherwise produce; clamp to >= 1.
    max_workers = max(
        1,
        min(32, (os.cpu_count() or 0) + 4, len(gateway.attributes.databases)),
    )

    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Run OPTIMADE DB queries in a thread pool, i.e., not using the main OS thread,
        # where the asyncio event loop is running.
        query_tasks = []
        for database in gateway.attributes.databases:
            query_params = await get_query_params(
                query_parameters=query.attributes.query_parameters,
                database_id=database.id,
                filter_mapping=filter_queries,
            )
            query_tasks.append(
                loop.run_in_executor(
                    executor=executor,
                    func=functools.partial(
                        db_find,
                        database=database,
                        endpoint=query.attributes.endpoint.value,
                        response_model=query.attributes.endpoint.
                        get_response_model(),
                        query_params=query_params,
                    ),
                ))

        # NOTE(review): if a task raises here, the query is left IN_PROGRESS.
        for query_task in query_tasks:
            (db_response, db_id) = await query_task

            await process_db_response(
                response=db_response,
                database_id=db_id,
                query=query,
                gateway=gateway,
            )

    # TODO: Pagination is not implemented yet; draft kept for reference.
    #
    # if isinstance(results, list) and get_resource_attribute(
    #     query,
    #     "attributes.response.meta.more_data_available",
    #     False,
    #     disambiguate=False,  # Extremely minor speed-up
    # ):
    #     # Deduce the `next` link from the current request
    #     query_string = urllib.parse.parse_qs(url.query)
    #     query_string["page_offset"] = [
    #         int(query_string.get("page_offset", [0])[0])  # type: ignore[list-item]
    #         + len(results[: query.attributes.query_parameters.page_limit])
    #     ]
    #     urlencoded = urllib.parse.urlencode(query_string, doseq=True)
    #     base_url = get_base_url(url)

    #     links = ToplevelLinks(next=f"{base_url}{url.path}?{urlencoded}")

    #     await update_query(query, "response.links", links)

    await update_query(query, "state", QueryState.FINISHED)
    return query.attributes.response
예제 #13
0
async def post_search(request: Request,
                      search: Search) -> QueriesResponseSingle:
    """`POST /search`

    Coordinate a new OPTIMADE query in multiple databases through a gateway:

    1. Search for gateway in DB using `optimade_urls` and `database_ids`
    1. Create [`GatewayCreate`][optimade_gateway.models.gateways.GatewayCreate] model
    1. `POST` gateway resource to get ID - using functionality of `POST /gateways`
    1. Create new [Query][optimade_gateway.models.queries.QueryCreate] resource
    1. `POST` Query resource - using functionality of `POST /queries`
    1. Return `POST /queries` response -
        [`QueriesResponseSingle`][optimade_gateway.models.responses.QueriesResponseSingle]

    Raises:
        BadRequest: If no (valid) OPTIMADE base URL can be derived from
            `search.database_ids` and `search.optimade_urls`.
        InternalServerError: If not all base URLs are `AnyUrl` instances, or
            if more stored databases than base URLs are found (ambiguous
            base URLs).

    """

    def _url_label(url: AnyUrl) -> str:
        """Return a `[user@]host[:port][/path]` label for an unregistered base URL."""
        user = f"{url.user}@" if url.user else ""
        port = f":{url.port}" if url.port else ""
        path = url.path.rstrip("/") if url.path else ""
        return f"{user}{url.host}{port}{path}"

    databases_collection = await collection_factory(CONFIG.databases_collection)
    # NOTE: It may be that the final list of base URLs (`base_urls`) contains the same
    # provider(s), but with differing base URLs, if, for example, a versioned base URL
    # is supplied.
    base_urls = set()

    if search.database_ids:
        databases = await databases_collection.get_multiple(filter={
            "id": {
                "$in": await clean_python_types(search.database_ids)
            }
        })
        base_urls |= {
            get_resource_attribute(database, "attributes.base_url")
            for database in databases if get_resource_attribute(
                database, "attributes.base_url") is not None
        }

    if search.optimade_urls:
        base_urls |= {_ for _ in search.optimade_urls if _ is not None}

    if not base_urls:
        msg = "No (valid) OPTIMADE URLs with:"
        if search.database_ids:
            msg += (
                f"\n  Database IDs: {search.database_ids} and corresponding found URLs: "
                f"{[get_resource_attribute(database, 'attributes.base_url') for database in databases]}"
            )
        if search.optimade_urls:
            msg += f"\n  Passed OPTIMADE URLs: {search.optimade_urls}"
        raise BadRequest(detail=msg)

    # Ensure all URLs are `pydantic.AnyUrl`s
    if not all(isinstance(_, AnyUrl) for _ in base_urls):
        raise InternalServerError(
            "Could unexpectedly not validate all base URLs as proper URLs.")

    databases = await databases_collection.get_multiple(
        filter={"base_url": {
            "$in": await clean_python_types(base_urls)
        }})

    if len(databases) > len(base_urls):
        LOGGER.error(
            "Found more database entries in MongoDB than the number of passed base URLs."
            " This suggests ambiguity in the base URLs of databases stored in MongoDB.\n"
            "  base_urls: %s\n  databases %s",
            base_urls,
            databases,
        )
        raise InternalServerError(
            "Ambiguous base URLs. See logs for more details.")

    if len(databases) < len(base_urls):
        # There are unregistered databases: add in-memory `LinksResource`s
        # for every base URL not matched by a stored database.
        current_base_urls = {
            get_resource_attribute(database, "attributes.base_url")
            for database in databases
        }
        for url in base_urls - current_base_urls:
            label = _url_label(url)
            databases.append(
                LinksResource(
                    id=label.replace(".", "__"),
                    type="links",
                    attributes=LinksResourceAttributes(
                        name=label,
                        description="",
                        base_url=url,
                        link_type=LinkType.CHILD,
                        homepage=None,
                    ),
                ))
    # Otherwise (equal counts) `databases` is expected to be the complete set
    # of requested databases.

    gateway = GatewayCreate(databases=databases)
    gateway, created = await resource_factory(gateway)

    if created:
        LOGGER.debug("A new gateway was created for a query (id=%r)",
                     gateway.id)
    else:
        LOGGER.debug("A gateway was found and reused for a query (id=%r)",
                     gateway.id)

    query = QueryCreate(
        endpoint=search.endpoint,
        gateway_id=gateway.id,
        query_parameters=search.query_parameters,
    )
    query, created = await resource_factory(query)

    if created:
        # The query runs in the background; this endpoint does not await it.
        asyncio.create_task(perform_query(url=request.url, query=query))

    collection = await collection_factory(CONFIG.queries_collection)

    return QueriesResponseSingle(
        links=ToplevelLinks(next=None),
        data=query,
        meta=meta_values(
            url=request.url,
            data_returned=1,
            data_available=await collection.acount(),
            more_data_available=False,
            **{f"_{CONFIG.provider.prefix}_created": created},
        ),
    )
예제 #14
0
async def get_search(
    request: Request,
    response: Response,
    search_params: SearchQueryParams = Depends(),
    entry_params: EntryListingQueryParams = Depends(),
) -> Union[QueriesResponseSingle, EntryResponseMany, ErrorResponse,
           RedirectResponse]:
    """`GET /search`

    Coordinate a new OPTIMADE query in multiple databases through a gateway:

    1. Create a [`Search`][optimade_gateway.models.search.Search] `POST` data - calling
        `POST /search`.
    1. Wait [`search_params.timeout`][optimade_gateway.queries.params.SearchQueryParams]
        seconds before returning the query, if it has not finished before.
    1. Return query - similar to `GET /queries/{query_id}`.

    This endpoint works similarly to `GET /queries/{query_id}`, where one passes the query
    parameters directly in the URL, instead of first POSTing a query and then going to its
    URL. Hence, a
    [`QueryResponseSingle`][optimade_gateway.models.responses.QueriesResponseSingle] is
    the standard response model for this endpoint.

    If the timeout time is reached and the query has not yet finished, the user is
    redirected to the specific URL for the query.

    If the `as_optimade` query parameter is `True`, the response will be parseable as a
    standard OPTIMADE entry listing endpoint like, e.g., `/structures`.
    For more information see the
    [OPTIMADE specification](https://github.com/Materials-Consortia/OPTIMADE/blob/master/optimade.rst#entry-listing-endpoints).

    Raises:
        BadRequest: If a `Search` cannot be built from the URL query parameters.
        RuntimeError: If `POST /search` does not return a query resource.

    """
    try:
        search = Search(
            query_parameters=OptimadeQueryParameters(
                **{
                    field: getattr(entry_params, field)
                    for field in OptimadeQueryParameters.__fields__
                    if getattr(entry_params, field)
                }),
            optimade_urls=search_params.optimade_urls,
            endpoint=search_params.endpoint,
            database_ids=search_params.database_ids,
        )
    except ValidationError as exc:
        # `errors` is a method; call it so the actual validation errors are
        # reported instead of the bound-method repr.
        raise BadRequest(detail=(
            "A Search object could not be created from the given URL query "
            f"parameters. Error(s): {exc.errors()}")) from exc

    queries_response = await post_search(request, search=search)

    if not queries_response.data:
        LOGGER.error("QueryResource not found in POST /search response:\n%s",
                     queries_response)
        raise RuntimeError(
            "Expected the response from POST /search to return a QueryResource, it did "
            "not")

    # Loop-invariant: fetch the queries collection once, not on every poll.
    collection = await collection_factory(CONFIG.queries_collection)
    deadline = time() + search_params.timeout

    # Poll at least once (so timeout=0 still returns an already finished
    # query), then every 0.1 s until the query finishes or the deadline passes.
    while True:  # pylint: disable=too-many-nested-blocks
        query: QueryResource = await collection.get_one(
            filter={"id": queries_response.data.id})

        if query.attributes.state == QueryState.FINISHED:
            if query.attributes.response and query.attributes.response.errors:
                # Use the first integer token of an error's `status` that is
                # >= 300 as the HTTP status; fall back to 500 otherwise.
                for error in query.attributes.response.errors:
                    if error.status:
                        for part in error.status.split(" "):
                            try:
                                response.status_code = int(part)
                                break
                            except ValueError:
                                pass
                        if response.status_code and response.status_code >= 300:
                            break
                else:
                    response.status_code = 500

            if search_params.as_optimade:
                return await query.response_as_optimade(url=request.url)

            return QueriesResponseSingle(
                links=ToplevelLinks(next=None),
                data=query,
                meta=meta_values(
                    url=request.url,
                    data_returned=1,
                    data_available=await collection.acount(),
                    more_data_available=False,
                ),
            )

        if time() >= deadline:
            break
        await asyncio.sleep(0.1)

    # The query has not yet finished and we're past the timeout time -> Redirect to
    # /queries/<id>
    return RedirectResponse(query.links.self)