Example #1
    async def _query(
        self,
        query: Query,
        priority: float = 0,  # lowest goes first
        timeout: Optional[float] = 60.0,
    ) -> dict:
        """Queue up a query for batching and return when completed"""
        if self.worker is None:
            raise RuntimeError(
                "Cannot send a request until a worker is running - enter the context"
            )

        # TODO figure out a way to remove this conversion
        query = Query.parse_obj(query)

        response_queue = asyncio.Queue()

        qgraphs = get_canonical_qgraphs(query.message.query_graph)

        for qgraph in qgraphs:

            subquery = Query(message=Message(query_graph=qgraph))

            # Queue query for processing
            request_id = str(uuid.uuid1())
            await self.request_queue.put(
                (
                    (priority, next(self.counter)),
                    (request_id, subquery, response_queue),
                )
            )

        combined_output = ReasonerResponse.parse_obj(
            {
                "message": {
                    "knowledge_graph": {"nodes": {}, "edges": {}},
                    "results": [],
                }
            }
        )

        for _ in qgraphs:
            # Wait for a response, honoring the caller-supplied timeout
            output: Union[ReasonerResponse, Exception] = await asyncio.wait_for(
                response_queue.get(),
                timeout=timeout,
            )

            if isinstance(output, Exception):
                raise output

            output.message.query_graph = None
            combined_output.message.update(output.message)

        combined_output.message.query_graph = query.message.query_graph

        return combined_output.dict()
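
A minimal usage sketch, assuming the enclosing class is an async context manager that starts self.worker on entry; the KPClient name and the query payload are hypothetical:

async def main():
    # Hypothetical client; entering the context starts the batching worker.
    async with KPClient(url="https://kp.example/query") as client:
        result = await client._query(
            Query.parse_obj({"message": {"query_graph": qgraph}}),  # qgraph built elsewhere
            priority=0,  # lower values are dequeued first
        )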
Example #2
async def subservice_callback(response: PDResponse, guid: str) -> int:
    """
    Receive an asynchronous callback from an ARAGORN subservice.

    :param response: the TRAPI response returned by the subservice
    :param guid: the unique id of the originating request
    :return: an HTTP status code
    """
    # init the return HTTP status code
    ret_val: int = 200

    logger.info(f'{guid}: Receiving sub-service callback')
    # logger.debug(f'{guid}: The sub-service response: {response.json()}')

    # init the connection
    connection = None

    try:
        # create a connection to the queue
        connection = await aio_pika.connect_robust(f"amqp://{q_username}:{q_password}@{q_host}/")

        # with the connection post to the queue
        async with connection:
            # get a channel to the queue
            channel = await connection.channel()

            # create a file path/name
            file_name = f'{queue_file_dir}/{guid}-async-data.json'

            # save the response data to a file
            with open(file_name, 'w') as data_file:
                data_file.write(response.json())

            # publish what was received for the sub-service. post the file name for the queue handler
            publish_val = await channel.default_exchange.publish(aio_pika.Message(body=file_name.encode()), routing_key=guid)

            if isinstance(publish_val, spec.Basic.Ack):
                logger.info(f'{guid}: Callback message published to queue.')
            else:
                # set the HTTP error code
                ret_val = 422

                logger.error(f'{guid}: Callback message publishing to queue failed, type: {type(publish_val)}')

    except Exception as e:
        # logger.exception already records the traceback; passing e as a
        # lazy-format argument to a message with no placeholders is an error
        logger.exception(f'{guid}: Exception detected while handling sub-service callback')

        # set the HTTP status code
        ret_val = 500
    finally:
        # close the connection to the queue if it exists
        if connection:
            await connection.close()

    # return the response code
    return ret_val
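
A sketch of how this callback might be wired into a FastAPI app; the route path and the APP object are assumptions, not from the source:

@APP.post('/callback/{guid}', status_code=200)
async def handle_callback(response: PDResponse, guid: str) -> JSONResponse:
    # delegate to the handler above and surface its HTTP status code
    ret_val = await subservice_callback(response, guid)
    return JSONResponse(content={'status': ret_val}, status_code=ret_val)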
Example #3
async def lookup(
    request: ReasonerQuery = Body(..., example=load_example("query")),
) -> Response:
    """Look up answers to the question."""
    trapi_query = request.dict(
        by_alias=True,
        exclude_unset=True,
    )
    try:
        await map_identifiers(trapi_query)
    except KeyError:
        pass
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{settings.robokop_kg}/query",
            json=trapi_query,
            timeout=None,
        )
        if response.status_code != 200:
            raise HTTPException(500, f"Failed doing lookup: {response.text}")

        response = await client.post(
            f"{settings.aragorn_ranker}/omnicorp_overlay",
            json=response.json(),
            timeout=None,
        )
        if response.status_code != 200:
            raise HTTPException(500, f"Failed doing overlay: {response.text}")

        response = await client.post(
            f"{settings.aragorn_ranker}/weight_correctness",
            json=response.json(),
            timeout=None,
        )
        if response.status_code != 200:
            raise HTTPException(500, f"Failed doing weighting: {response.text}")

        response = await client.post(
            f"{settings.aragorn_ranker}/score",
            json=response.json(),
            timeout=None,
        )
        if response.status_code != 200:
            raise HTTPException(500, f"Failed doing scoring: {response.text}")
    return Response(**response.json())
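
The four posts differ only in URL and error label, so the chain collapses naturally into a data-driven loop; a behavior-preserving sketch using the same settings object:

steps = [
    (f"{settings.robokop_kg}/query", "lookup"),
    (f"{settings.aragorn_ranker}/omnicorp_overlay", "overlay"),
    (f"{settings.aragorn_ranker}/weight_correctness", "weighting"),
    (f"{settings.aragorn_ranker}/score", "scoring"),
]
payload = trapi_query
async with httpx.AsyncClient() as client:
    for url, label in steps:
        response = await client.post(url, json=payload, timeout=None)
        if response.status_code != 200:
            raise HTTPException(500, f"Failed doing {label}: {response.text}")
        payload = response.json()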
Example #4
async def query(request: PDResponse, *, jaccard_like: bool = False):
    """Score answers.

    This is mostly glue around the heavy lifting in ranker_obj.Ranker
    """
    # get the message into a dict
    in_message = request.dict()

    # save the logs for the response (if any)
    if 'logs' not in in_message or in_message['logs'] is None:
        in_message['logs'] = []

    # init the status code
    status_code: int = 200

    # get a reference to the entire message
    message = in_message['message']

    # get a reference to the results
    answers = message['results']

    try:
        # resistance distance ranking
        pr = Ranker(message)

        # rank the answers. there should be a score for each bound result after this
        answers = pr.rank(answers, jaccard_like=jaccard_like)

        # save the results
        message['results'] = answers
    except Exception as e:
        # put the error in the response
        status_code = 500

        # save any log entries
        in_message['logs'].append(
            create_log_entry(f'Exception: {str(e)}', 'ERROR'))

    # validate the response and get it into json
    in_message = jsonable_encoder(PDResponse(**in_message))

    # return the result to the caller
    return JSONResponse(content=in_message, status_code=status_code)
Example #5
    async def answer_question(query: Query) -> Response:
        """Get results for query graph."""
        query = query.dict(exclude_unset=True)
        workflow = query.get("workflow", [{"id": "lookup"}])
        if len(workflow) > 1:
            raise HTTPException(
                400, "Binder does not support workflows of length >1")
        operation = workflow[0]
        qgraph = query["message"]["query_graph"]
        if operation["id"] == "lookup":
            async with KnowledgeProvider(database_file, **kwargs) as kp:
                kgraph, results = await kp.get_results(qgraph)
        elif operation["id"] == "bind":
            kgraph = query["message"]["knowledge_graph"]
            knodes = [
                {
                    "id": knode_id,
                    "category": knode.get("categories", ["biolink:NamedThing"])[0],
                }
                for knode_id, knode in kgraph["nodes"].items()
            ]
            kedges = [
                {
                    "id": kedge_id,
                    "subject": kedge["subject"],
                    "predicate": kedge["predicate"],
                    "object": kedge["object"],
                }
                for kedge_id, kedge in kgraph["edges"].items()
            ]

            async with KnowledgeProvider(":memory:", **kwargs) as kp:
                await add_data(kp.db, knodes, kedges)
                kgraph, results = await kp.get_results(qgraph)
        else:
            raise HTTPException(400, f"Unsupported operation {operation}")

        response = {
            "message": {
                "knowledge_graph": kgraph,
                "results": results,
                "query_graph": qgraph,
            }
        }
        return Response.parse_obj(response)
Example #6
async def run_workflow(
    request: ReasonerQuery = Body(..., example=load_example("query")),
) -> Response:
    """Run workflow."""
    request_dict = request.dict(exclude_unset=True)

    message = request_dict["message"]
    workflow = request_dict["workflow"]
    logger = gen_logger()
    qgraph = message["query_graph"]
    kgraph = {"nodes": {}, "edges": {}}
    if "knowledge_graph" in message.keys():
        if "nodes" in message["knowledge_graph"].keys():
            kgraph = message["knowledge_graph"]

    async with httpx.AsyncClient(verify=False, timeout=60.0) as client:
        for operation in workflow:
            service_operation_responses = []
            for service in SERVICES[operation["id"]]:
                url = service["url"]
                service_name = service["title"]
                logger.debug(
                    f"Requesting operation '{operation}' from {service_name}..."
                )
                try:
                    response = await post_safely(
                        url,
                        {
                            "message": message,
                            "workflow": [
                                operation,
                            ],
                            "submitter": "Workflow Runner",
                        },
                        client=client,
                        timeout=60.0,
                        logger=logger,
                        service_name=service_name,
                    )
                    logger.debug(
                        f"Received operation '{operation}' from {service_name}..."
                    )

                    try:
                        response = await post_safely(
                            NORMALIZER_URL + "/response",
                            {
                                "message": response["message"],
                                "submitter": "Workflow Runner",
                            },
                            client=client,
                            timeout=60.0,
                            logger=logger,
                            service_name="node_normalizer",
                        )
                    except RuntimeError as e:
                        logger.warning({"error": str(e)})

                    service_operation_responses.append(response)

                except RuntimeError as e:
                    logger.warning({"error": str(e)})

                if not OPERATIONS[operation["id"]]["unique"] and len(
                        service_operation_responses) == 1:
                    # We only need one successful response for non-unique operations
                    break

            logger.debug(
                f"Merging {len(service_operation_responses)} responses for '{operation}'..."
            )
            m = Message(
                query_graph=QueryGraph.parse_obj(qgraph),
                knowledge_graph=KnowledgeGraph.parse_obj(kgraph),
            )
            for response in service_operation_responses:
                response["message"]["query_graph"] = qgraph
                m.update(Message.parse_obj(response["message"]))
            message = m.dict()

    return Response(
        message=message,
        workflow=workflow,
        logs=logger.handlers[0].store,
    )
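
run_workflow assumes SERVICES maps each operation id to a list of candidate services and OPERATIONS flags whether an operation must run on every service; an illustrative sketch of those shapes (the values are made up, the real registries are built elsewhere):

SERVICES = {
    "lookup": [
        {"title": "strider", "url": "https://strider.example/query"},
    ],
}
OPERATIONS = {
    "lookup": {"unique": False},  # one successful response is enough
}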
Example #7
async def query(
        request: PDResponse,
        relevance: Optional[float] = Query(
            0.0025,
            description='Portion of cooccurrence publications relevant to a question',
        ),
        wt_min: Optional[float] = Query(
            0.0,
            description='Minimum weight (at 0 publications)',
        ),
        wt_max: Optional[float] = Query(
            1.0,
            description='Maximum weight (at inf publications)',
        ),
        p50: Optional[float] = Query(
            2.0,
            description='Publications at 50% of wt_max',
        ),
):
    """Weight kgraph edges based on metadata.

    "19 pubs from CTD is a 1, and 2 should at least be 0.5"
        - cbizon
    """
    in_message = request.dict()

    # save the logs for the response (if any)
    if 'logs' not in in_message or in_message['logs'] is None:
        in_message['logs'] = []

    # init the status code
    status_code: int = 200

    message = in_message['message']

    def sigmoid(x):
        """Scale with partial sigmoid - the right (concave down) half.

        Such that:
           f(0) = wt_min
           f(inf) = wt_max
           f(p50) = 0.5 * wt_max
        """
        a = 2 * (wt_max - wt_min)
        r = 0.5 * wt_max
        c = wt_max - 2 * wt_min
        k = 1 / p50 * (math.log(r + c) - math.log(a - r - c))
        return a / (1 + math.exp(-k * x)) - c

    def create_log_entry(msg: str, err_level, code=None) -> dict:
        # load the data
        ret_val = {
            'timestamp': str(datetime.now()),
            'level': err_level,
            'message': msg,
            'code': code
        }

        # return to the caller
        return ret_val

    try:
        # constant count of all publications
        all_pubs = 27840000

        # get the data nodes we need
        results = message['results']
        kgraph = message['knowledge_graph']

        # storage for the publication counts for the node
        node_pubs: dict = {}

        # for each node in the knowledge graph
        for n in kgraph['nodes']:
            # init the count value
            omnicorp_article_count: int = 0

            # get the article count attribute
            for p in kgraph['nodes'][n]['attributes']:
                # is this what we are looking for
                if p['original_attribute_name'] == 'omnicorp_article_count':
                    # save it
                    omnicorp_article_count = p['value']

                    # no need to continue
                    break

            # add the node id and count to the dict
            node_pubs.update({n: omnicorp_article_count})

        # map kedges to result edge bindings
        krmap = defaultdict(list)

        # for each result listed in the data get a map reference and default the weight attribute
        for result in results:
            # for every edge binding result
            for eb in result['edge_bindings']:
                # loop through the edge binding
                for idx, binding_val in enumerate(result['edge_bindings'][eb]):
                    # get a reference to the weight for easy update later
                    krmap[binding_val['id']] = result['edge_bindings'][eb][idx]

                    found = False

                    # is there already a list of attributes
                    if 'attributes' in krmap[binding_val['id']]:
                        # loop through the attributes
                        for item in krmap[binding_val['id']]['attributes']:
                            # search for the weight attribute
                            if item['original_attribute_name'].startswith('weight'):
                                found = True
                                break

                    # was the attribute found
                    if not found:
                        if 'attributes' not in krmap[binding_val['id']]:
                            krmap[binding_val['id']]['attributes'] = []

                        # create an Attribute
                        krmap[binding_val['id']]['attributes'].append({
                            'original_attribute_name': 'weight',
                            'attribute_type_id': 'biolink:has_numeric_value',
                            'value': 1,
                            'value_type_id': 'EDAM:data_1669'})

        # get the knowledge graph edges
        edges = kgraph['edges']

        # for each knowledge graph edge
        for edge in edges:
            # We are getting some results back (BTE?) that have "publications": ['PMID:1234|2345|83984']
            attributes = edges[edge].get('attributes', None)

            # init storage for the publications and their count
            publications = []
            num_publications = 0

            if attributes is not None:
                # for each data attribute collect the needed params
                for attribute in attributes:
                    if attribute['original_attribute_name'] is not None:
                        # is this the publication list
                        if attribute['original_attribute_name'].startswith('publications'):
                            publications = attribute['value']
                        # else is this the number of publications
                        elif attribute['original_attribute_name'].startswith('num_publications'):
                            num_publications = attribute.get('value', 0)

                # if only one publication value was found, ensure it wasn't a character-separated list
                if len(publications) == 1:
                    if '|' in publications[0]:
                        publications = publications[0].split('|')
                    elif ',' in publications[0]:
                        publications = publications[0].split(',')

                    # get the real publication count
                    num_publications = len(publications)

                # if there was no publication count found revert to the number of individual values
                if num_publications == 0:
                    num_publications = len(publications)

                # now the nicer cleaner version when we have publications as an actual array
                # edge_pubs = edge.get('num_publications', len(edge.get('publications', [])))
                if edges[edge].get('predicate') == 'literature_co-occurrence':
                    # edge is the key into the edges dict, so look up the
                    # subject/object on the edge data, not on the key itself
                    subject_pubs = int(node_pubs[edges[edge]['subject']])
                    object_pubs = int(node_pubs[edges[edge]['object']])

                    cov = (num_publications / all_pubs) - (subject_pubs / all_pubs) * (object_pubs / all_pubs)
                    cov = max((cov, 0.0))
                    effective_pubs = cov * all_pubs * relevance
                else:
                    effective_pubs = num_publications + 1  # consider the curation a pub

                # if there is something to add this new attribute to
                if len(krmap[edge]) != 0:
                    # is there already a list of attributes
                    if 'attributes' in krmap[edge]:
                        # loop through the attributes
                        for item in krmap[edge]['attributes']:
                            # search for the weight attribute
                            if item['original_attribute_name'].startswith('weight'):
                                # update the params
                                item['attribute_type_id'] = 'biolink:has_numeric_value'
                                item['value'] = item['value'] * sigmoid(effective_pubs)
                                item['value_type_id'] = 'EDAM:data_1669'
                                break

        # save the new knowledge graph data
        message['knowledge_graph'] = kgraph

    except Exception as e:
        # put the error in the response
        status_code = 500

        # save any log entries
        in_message['logs'].append(create_log_entry(f'Exception: {str(e)}', 'ERROR'))

    # validate the response and get it into JSON
    in_message = jsonable_encoder(PDResponse(**in_message))

    # return the result to the caller
    return JSONResponse(content=in_message, status_code=status_code)
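
The sigmoid's constants are chosen so its stated boundary conditions hold; a quick standalone check with the default parameters (wt_min=0.0, wt_max=1.0, p50=2.0):

import math

wt_min, wt_max, p50 = 0.0, 1.0, 2.0
a = 2 * (wt_max - wt_min)
r = 0.5 * wt_max
c = wt_max - 2 * wt_min
k = 1 / p50 * (math.log(r + c) - math.log(a - r - c))
f = lambda x: a / (1 + math.exp(-k * x)) - c

assert abs(f(0) - wt_min) < 1e-9          # f(0) = wt_min
assert abs(f(p50) - 0.5 * wt_max) < 1e-9  # f(p50) = wt_max / 2
# and f(x) -> a - c = wt_max as x -> infinity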
Example #8
async def coalesce_handler(request: PDResponse, method: MethodName):
    """ Answer coalesce operations. You may choose all, property, graph. """

    # convert the incoming message into a dict
    in_message = request.dict()

    # save the logs for the response (if any)
    if 'logs' not in in_message or in_message['logs'] is None:
        in_message['logs'] = []

    # these timestamps are causing json serialization issues in call to the normalizer
    # so here we convert them to strings.
    for log in in_message['logs']:
        log['timestamp'] = str(log['timestamp'])

    # make sure there are results to coalesce
    # 0 results is perfectly legal, there's just nothing to do.
    if not in_message['message'].get('results'):
        status_code = 200
        logger.error("No results to coalesce")
        # in_message['logs'].append(create_log_entry(f'No results to coalesce', "WARNING"))
        return JSONResponse(content=in_message, status_code=status_code)

    elif not in_message['message'].get('knowledge_graph'):
        # This is a 422 because we do have results, but there's no graph to use.
        status_code = 422
        logger.error("No knowledge graph to coalesce")
        # in_message['logs'].append(create_log_entry(f'No knowledge graph to coalesce', "ERROR"))
        return JSONResponse(content=in_message, status_code=status_code)

    # init the status code
    status_code: int = 200

    # get the message to work on
    coalesced = in_message['message']

    try:
        # call the operation with the message in the request message
        coalesced = coalesce(coalesced, method=method)

        # turn it back into a full trapi message
        in_message['message'] = coalesced

        # import json
        # with open('ac_out_attributes.json', 'w') as tf:
        #     tf.write(json.dumps(in_message, default=str))

        # # Normalize the data
        # coalesced = normalize(in_message)
        #
        # # save the response in the incoming message
        # in_message['message'] = coalesced['message']

    except Exception as e:
        # put the error in the response
        status_code = 500
        logger.exception(f"Exception encountered {str(e)}")
        # in_message['logs'].append(create_log_entry(f'Exception {str(e)}', "ERROR"))

    # return the result to the caller
    return JSONResponse(content=in_message, status_code=status_code)
Example #9
async def query(request: PDResponse):
    """Add support to message.

    Add support edges to knowledge_graph and bindings to results.
    """
    in_message = request.dict()

    # save the logs for the response (if any)
    if 'logs' not in in_message or in_message['logs'] is None:
        in_message['logs'] = []

    # init the status code
    status_code: int = 200

    message = in_message['message']

    qgraph = message['query_graph']
    kgraph = message['knowledge_graph']
    answers = message['results']

    # get cache if possible
    try:
        cache = Cache(
            redis_host=CACHE_HOST,
            redis_port=CACHE_PORT,
            redis_db=CACHE_DB,
            redis_password=CACHE_PASSWORD,
        )
    except Exception as e:
        logger.exception(e)
        cache = None

    redis_batch_size = 100

    try:
        async with OmnicorpSupport() as supporter:
            # get all node supports

            keys = [
                f"{supporter.__class__.__name__}({node})"
                for node in kgraph['nodes']
            ]
            values = []
            for batch in batches(keys, redis_batch_size):
                # cache may be None if the Redis connection failed above
                values.extend(cache.mget(*batch) if cache else [None] * len(batch))

            jobs = [
                count_node_pmids(supporter, node, key, value, cache,
                                 kgraph['nodes'])
                for node, value, key in zip(kgraph['nodes'], values, keys)
            ]

            # which qgraph nodes are sets?
            qgraph_setnodes = {
                n for n in qgraph['nodes']
                if qgraph['nodes'][n].get('is_set')
            }

            # Generate a set of pairs of node curies
            pair_to_answer = defaultdict(set)  # a map of node pairs to answers
            for ans_idx, answer_map in enumerate(answers):

                # Get all nodes that are not part of sets and densely connect them
                # can be str (not a set) or list (could be a set or not a set)
                nonset_nodes = []
                setnodes = {}

                # node binding results is now a dict containing dicts that contain a list of dicts.
                for nb in answer_map['node_bindings']:
                    if nb in qgraph_setnodes:
                        setnodes[nb] = [
                            node['id']
                            for node in answer_map['node_bindings'][nb]
                        ]
                    else:
                        if len(answer_map['node_bindings'][nb]) != 0:
                            nonset_nodes.append(
                                answer_map['node_bindings'][nb][0]['id'])

                nonset_nodes = sorted(nonset_nodes)
                # nodes = sorted([nb['kg_id'] for nb in answer_map['node_bindings'] if isinstance(nb['kg_id'], str)])
                for node_pair in combinations(nonset_nodes, 2):
                    pair_to_answer[node_pair].add(ans_idx)

                # set_nodes_list_list = [nb['kg_id'] for nb in answer_map['node_bindings'] if isinstance(nb['kg_id'], list)]
                # set_nodes = [n for el in set_nodes_list_list for n in el]
                # For all nodes that are within sets, connect them to all nodes that are not in sets
                for qg_id, snodes in setnodes.items():
                    for snode in snodes:
                        for node in nonset_nodes:
                            node_pair = tuple(sorted((node, snode)))
                            pair_to_answer[node_pair].add(ans_idx)

                # now all nodes in set a to all nodes in set b
                for qga, qgb in combinations(setnodes.keys(), 2):
                    for anode in setnodes[qga]:
                        for bnode in setnodes[qgb]:
                            # node_pair = tuple(sorted(anode, bnode))
                            node_pair = tuple(sorted((anode, bnode)))
                            pair_to_answer[node_pair].add(ans_idx)

            # get all pair supports
            cached_prefixes = cache.get('OmnicorpPrefixes') if cache else None

            keys = [
                f"{supporter.__class__.__name__}_count({pair[0]},{pair[1]})"
                for pair in pair_to_answer
            ]
            values = []
            for batch in batches(keys, redis_batch_size):
                # again, tolerate a missing cache
                values.extend(cache.mget(*batch) if cache else [None] * len(batch))

            jobs.extend([
                count_shared_pmids(
                    supporter,
                    support_idx,
                    pair,
                    key,
                    value,
                    cache,
                    cached_prefixes,
                    kgraph,
                    pair_to_answer,
                    answers,
                ) for support_idx, (
                    pair, value,
                    key) in enumerate(zip(pair_to_answer, values, keys))
            ])
            await asyncio.gather(*jobs)

        # load the new results into the response
        message['knowledge_graph'] = kgraph
        message['results'] = answers

    except Exception as e:
        # put the error in the response
        status_code = 500

        # save any log entries
        in_message['logs'].append(
            create_log_entry(f'Exception: {str(e)}', 'ERROR'))

    # validate the response and get it into JSON
    in_message = jsonable_encoder(PDResponse(**in_message))

    # return the result to the caller
    return JSONResponse(content=in_message, status_code=status_code)
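
The pair generation above densely connects each answer's non-set nodes, then set members against everything else; a toy trace with hypothetical curies:

from collections import defaultdict
from itertools import combinations

pair_to_answer = defaultdict(set)
nonset_nodes = sorted(['CHEBI:1', 'MONDO:2'])
setnodes = {'n2': ['HGNC:3', 'HGNC:4']}
ans_idx = 0

for node_pair in combinations(nonset_nodes, 2):
    pair_to_answer[node_pair].add(ans_idx)
for snodes in setnodes.values():
    for snode in snodes:
        for node in nonset_nodes:
            pair_to_answer[tuple(sorted((node, snode)))].add(ans_idx)

# pair_to_answer now holds five sorted pairs, each mapped to {0}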
Example #10
async def query(response: Response, *, exclude_sets=False) -> Response:
    """Compute informativeness weights for edges."""
    message = response.message.dict()
    qgraph = message['query_graph']
    results = message['results']

    qnodes = qgraph['nodes']
    qedges = qgraph['edges']

    # knode_map = {knode['id']: knode for knode in knodes}
    qnode_map = {qnode['id']: qnode for qnode in qnodes}
    qedge_map = {qedge['id']: qedge for qedge in qedges}

    driver = Neo4jDatabase(
        url=NEO4J_URL,
        credentials={
            'username': NEO4J_USER,
            'password': NEO4J_PASSWORD,
        },
    )
    redges_by_id = dict()
    count_plans = defaultdict(lambda: defaultdict(list))
    for kdx, result in enumerate(results):
        rgraph = get_rgraph(result, message)
        redges_by_id.update({(kdx, redge['id']): redge
                             for redge in rgraph['edges']})

        for redge in rgraph['edges']:
            if (not exclude_sets) or qnode_map[redge['qg_target_id']].get('set', False):
                count_plans[redge['kg_source_id']][
                    (redge['eb']['qg_id'], redge['qg_target_id'])
                ].append((kdx, redge['id']))
            if (not exclude_sets) or qnode_map[redge['qg_source_id']].get('set', False):
                count_plans[redge['kg_target_id']][
                    (redge['eb']['qg_id'], redge['qg_source_id'])
                ].append((kdx, redge['id']))

    count_to_redge = {}
    for ldx, batch in enumerate(batches(list(count_plans.keys()), 1000)):
        batch_bits = []
        for idx, ksource_id in enumerate(batch):
            sets = []
            plan = count_plans[ksource_id]
            anchor_node_reference = NodeReference({
                'id': f'n{idx:04d}',
                'curie': ksource_id,
                'type': 'named_thing'
            })
            anchor_node_reference = str(anchor_node_reference)

            # base = f"MATCH ({anchor_node_reference}) "

            for jdx, (qlink, redge_ids) in enumerate(plan.items()):
                cypher_counts = []
                qedge_id, qtarget_id = qlink
                count_id = f"c{idx:03d}{chr(97 + jdx)}"
                qedge = qedge_map[qedge_id]
                edge_reference = EdgeReference(qedge, anonymous=True)
                anon_node_reference = NodeReference({
                    **qnode_map[qtarget_id],
                    'id': count_id,
                })

                if qedge['source_id'] == qtarget_id:
                    source_reference = anon_node_reference
                    target_reference = anchor_node_reference
                elif qedge['target_id'] == qtarget_id:
                    source_reference = anchor_node_reference
                    target_reference = anon_node_reference
                else:
                    # guard against unbound references: qtarget_id must be an
                    # endpoint of this qedge
                    raise ValueError(f"{qtarget_id} is not an endpoint of {qedge_id}")

                cypher_counts.append(
                    f"{anon_node_reference.name}: count(DISTINCT {anon_node_reference.name})"
                )
                count_to_redge[count_id] = redge_ids
                sets.append(
                    f'MATCH ({source_reference}){edge_reference}({target_reference})'
                    + ' RETURN {' + ', '.join(cypher_counts) + '} as output')
            batch_bits.append(' UNION ALL '.join(sets))
        cypher = ' UNION ALL '.join(batch_bits)
        response = driver.run(cypher)

        degrees = {
            key: value
            for result in response for key, value in result['output'].items()
        }

        for key in degrees:
            for redge_id in count_to_redge[key]:
                eb = redges_by_id[redge_id]['eb']
                eb['weight'] = eb.get('weight', 1.0) / degrees[key]

    message['results'] = results

    # get this in the correct response model format
    ret_val = {'message': message}

    # return the message back to the caller
    return Response(**ret_val)
Example #11
async def _post_safely(
    client: httpx.AsyncClient,
    url: str,
    payload: Any,
    timeout: Optional[float] = None,
    logger: Optional[logging.Logger] = None,
    service_name: Optional[str] = None,
):
    if not logger:
        logger = logging.getLogger(__name__)
    if not service_name:
        service_name = url
    try:
        # use wait_for instead of httpx's timeout because:
        # https://github.com/encode/httpx/issues/1451#issuecomment-907400740
        response = await asyncio.wait_for(
            client.post(url, json=payload),
            timeout=timeout,
        )
        response.raise_for_status()
        response_json = response.json()
        Response(**response_json)  # validate against TRAPI
        return response_json
    except asyncio.TimeoutError as e:
        logger.warning({
            "message": f"{service_name} took >{timeout} seconds to respond",
            "error": str(e),
            "request": {
                "url": url,
                "data": elide_curies(payload),
            },
        })
    except httpx.RequestError as e:
        # Log error
        logger.warning({
            "message": f"Error contacting {service_name}",
            "error": str(e),
            "request": log_request(e.request),
        })
    except httpx.HTTPStatusError as e:
        # Log error with response
        logger.warning({
            "message": f"Error response from {service_name}",
            "error": str(e),
            "request": log_request(e.request),
            "response": log_response(e.response),
        })
    except json.JSONDecodeError as e:
        # Log error with response
        logger.warning({
            "message": f"Received bad JSON data from {service_name}",
            "request": {
                "data": payload,
            },
            "response": {
                "data": e.doc
            },
            "error": str(e),
        })
    except pydantic.ValidationError as e:
        logger.warning({
            "message": f"Received non-TRAPI compliant response from {service_name}",
            "error": str(e),
        })
    except Exception as e:
        traceback.print_exc()
        logger.warning({
            "message": f"Something went wrong while querying {service_name}",
            "error": str(e),
        })
    raise RuntimeError(f"Failed to get a good response from {service_name}, see the logs")
Example #12
async def normalize_response(response: Response) -> Response:
    """
    Normalize the message of a TRAPI-compliant response.
    """
    response.message = await normalize_message(app, response.message)
    return response
Example #13
    async def process_batch(self):
        """Set up a subscriber to process batching."""
        # Initialize the TAT
        #
        # TAT = Theoretical Arrival Time
        # When the next request should be sent
        # to adhere to the rate limit.
        #
        # This is an implementation of the GCRA algorithm
        # More information can be found here:
        # https://dev.to/astagi/rate-limiting-using-python-and-redis-58gk
        if self.request_qty > 0:
            interval = self.request_duration / self.request_qty
            tat = datetime.datetime.utcnow() + interval

        while True:
            # Get everything in the stream or wait for something to show up
            priority, (
                request_id,
                payload,
                response_queue,
            ) = await self.request_queue.get()
            priorities = {request_id: priority}
            request_value_mapping = {request_id: payload}
            response_queues = {request_id: response_queue}
            while True:
                if (self.max_batch_size is not None
                        and len(request_value_mapping) == self.max_batch_size):
                    break
                try:
                    priority, (
                        request_id,
                        payload,
                        response_queue,
                    ) = self.request_queue.get_nowait()
                except QueueEmpty:
                    break
                priorities[request_id] = priority
                request_value_mapping[request_id] = payload
                response_queues[request_id] = response_queue

            # Extract a curie mapping from each request
            request_curie_mapping = {
                request_id: get_curies(request_value["message"]["query_graph"])
                for request_id, request_value in request_value_mapping.items()
            }

            # Find requests that are the same (those that we can merge)
            # This disregards non-matching IDs because the IDs have been
            # removed with remove_curies below
            stripped_qgraphs = {
                request_id: remove_curies(request["message"]["query_graph"])
                for request_id, request in request_value_mapping.items()
            }
            first_value = next(iter(stripped_qgraphs.values()))

            batch_request_ids = get_keys_with_value(
                stripped_qgraphs,
                first_value,
            )

            # Re-queue the un-selected requests
            for request_id in request_value_mapping:
                if request_id not in batch_request_ids:
                    await self.request_queue.put((
                        priorities[request_id],
                        (
                            request_id,
                            request_value_mapping[request_id],
                            response_queues[request_id],
                        ),
                    ))

            request_value_mapping = {
                k: v
                for k, v in request_value_mapping.items()
                if k in batch_request_ids
            }

            # Filter curie mapping to only include matching requests
            request_curie_mapping = {
                k: v
                for k, v in request_curie_mapping.items()
                if k in batch_request_ids
            }

            # Pull first value from request_value_mapping
            # to use as a template for our merged request
            merged_request_value = copy.deepcopy(
                next(iter(request_value_mapping.values())))

            # Remove qnode ids
            for qnode in merged_request_value["message"]["query_graph"][
                    "nodes"].values():
                qnode.pop("ids", None)

            # Update merged request using curie mapping
            for curie_mapping in request_curie_mapping.values():
                for node_id, node_curies in curie_mapping.items():
                    node = merged_request_value["message"]["query_graph"][
                        "nodes"][node_id]
                    if "ids" not in node:
                        node["ids"] = []
                    node["ids"].extend(node_curies)
            for qnode in merged_request_value["message"]["query_graph"][
                    "nodes"].values():
                if qnode.get("ids"):
                    qnode["ids"] = list(set(qnode["ids"]))

            response_values = dict()
            try:
                # Make request
                self.logger.info(
                    "[{id}] Sending request made of {subrequests} subrequests ({curies} curies)".format(
                        id=self.id,
                        subrequests=len(request_curie_mapping),
                        curies=" x ".join(
                            str(len(qnode.get("ids", []) or []))
                            for qnode in merged_request_value["message"]["query_graph"]["nodes"].values()
                        ),
                    ))
                self.logger.context = self.id
                merged_request_value = await self.preproc(
                    merged_request_value, self.logger)
                merged_request_value["submitter"] = "infores:aragorn"
                async with httpx.AsyncClient() as client:
                    response = await client.post(
                        self.url,
                        json=merged_request_value,
                        timeout=self.timeout,
                    )
                if response.status_code == 429:
                    # reset TAT (skip when request_qty == 0: no rate limit is
                    # enforced, and the division below would fail)
                    if self.request_qty > 0:
                        interval = self.request_duration / self.request_qty
                        tat = datetime.datetime.utcnow() + interval
                    # re-queue requests
                    for request_id in request_value_mapping:
                        await self.request_queue.put((
                            priorities[request_id],
                            (
                                request_id,
                                request_value_mapping[request_id],
                                response_queues[request_id],
                            ),
                        ))
                    # try again later
                    continue

                response.raise_for_status()

                # Parse with reasoner_pydantic to validate
                response_body = ReasonerResponse.parse_obj(
                    response.json()).dict()
                response_body = await self.postproc(response_body)
                message = response_body["message"]
                results = message.get("results") or []
                self.logger.info(
                    "[{}] Received response with {} results in {} seconds".format(
                        self.id,
                        len(results),
                        response.elapsed.total_seconds(),
                    ))

                if len(request_curie_mapping) == 1:
                    request_id = next(iter(request_curie_mapping))
                    response_values[request_id] = {
                        "message": {
                            "query_graph": request_value_mapping[request_id]["message"]["query_graph"],
                            "knowledge_graph": message.get("knowledge_graph", {"nodes": {}, "edges": {}}),
                            "results": message.get("results", []),
                        }
                    }
                else:
                    # Split using the request_curie_mapping
                    for request_id, curie_mapping in request_curie_mapping.items():
                        try:
                            kgraph, results = filter_by_curie_mapping(
                                message, curie_mapping, kp_id=self.id)
                            response_values[request_id] = {
                                "message": {
                                    "query_graph": request_value_mapping[request_id]["message"]["query_graph"],
                                    "knowledge_graph": kgraph,
                                    "results": results,
                                }
                            }
                        except BatchingError as err:
                            # the response is probably malformed
                            response_values[request_id] = err
            except (
                    asyncio.exceptions.TimeoutError,
                    httpx.RequestError,
                    httpx.HTTPStatusError,
                    JSONDecodeError,
                    pydantic.ValidationError,
            ) as e:
                for request_id, curie_mapping in request_curie_mapping.items():
                    response_values[request_id] = {
                        "message": {
                            "query_graph": request_value_mapping[request_id]["message"]["query_graph"],
                            "knowledge_graph": {"nodes": {}, "edges": {}},
                            "results": [],
                        },
                    }
                if isinstance(e, asyncio.TimeoutError):
                    self.logger.warning({
                        "message": f"{self.id} took >60 seconds to respond",
                        "error": str(e),
                        "request": elide_curies(merged_request_value),
                    })
                elif isinstance(e, httpx.ReadTimeout):
                    self.logger.warning({
                        "message": f"{self.id} took >60 seconds to respond",
                        "error": str(e),
                        "request": log_request(e.request),
                    })
                elif isinstance(e, httpx.RequestError):
                    # Log error
                    self.logger.warning({
                        "message": f"Request Error contacting {self.id}",
                        "error": str(e),
                        "request": log_request(e.request),
                    })
                elif isinstance(e, httpx.HTTPStatusError):
                    # Log error with response
                    self.logger.warning({
                        "message": f"Response Error contacting {self.id}",
                        "error": str(e),
                        "request": log_request(e.request),
                        "response": log_response(e.response),
                    })
                elif isinstance(e, JSONDecodeError):
                    # Log error with the offending document
                    # (JSONDecodeError carries no request/response attributes)
                    self.logger.warning({
                        "message": f"Received bad JSON data from {self.id}",
                        "response": e.doc,
                        "error": str(e),
                    })
                elif isinstance(e, pydantic.ValidationError):
                    self.logger.warning({
                        "message": f"Received non-TRAPI compliant response from {self.id}",
                        "error": str(e),
                    })
                else:
                    self.logger.warning({
                        "message": f"Something went wrong while querying {self.id}",
                        "error": str(e),
                    })

            for request_id, response_value in response_values.items():
                # Hand the finished value back to the waiting caller
                await response_queues[request_id].put(response_value)

            # if request_qty == 0 we don't enforce the rate limit
            if self.request_qty > 0:
                time_remaining_seconds = (
                    tat - datetime.datetime.utcnow()).total_seconds()

                # Wait for TAT
                if time_remaining_seconds > 0:
                    await asyncio.sleep(time_remaining_seconds)

                # Update TAT
                tat = datetime.datetime.utcnow() + interval
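
The TAT bookkeeping above is the heart of GCRA; stripped of the batching, the rate limiting reduces to this sketch (the names are illustrative):

import asyncio
import datetime

async def rate_limited(worker, request_qty: int, request_duration: datetime.timedelta):
    """Allow at most request_qty calls to worker per request_duration."""
    interval = request_duration / request_qty
    tat = datetime.datetime.utcnow() + interval  # theoretical arrival time
    while True:
        await worker()
        remaining = (tat - datetime.datetime.utcnow()).total_seconds()
        if remaining > 0:
            await asyncio.sleep(remaining)
        tat = datetime.datetime.utcnow() + interval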