Example #1
def filter_by_curie_mapping(
    message: Message,
    curie_mapping: dict[str, list[str]],
    kp_id: str = "KP",
) -> tuple[dict, list]:
    """
    Filter a message to ensure that all results
    contain the bindings specified in the curie_mapping
    """
    # Only keep results that have node bindings
    # matching the curies given in curie_mapping
    results = [
        result for result in (message.get("results") or [])
        if result_contains_node_bindings(result, curie_mapping)
    ]

    # Construct result-specific knowledge graph
    kgraph = {
        "nodes": {
            binding["id"]: message["knowledge_graph"]["nodes"][binding["id"]]
            for result in results
            for _, bindings in result["node_bindings"].items()
            for binding in bindings
        },
        "edges": {
            binding["id"]: message["knowledge_graph"]["edges"][binding["id"]]
            for result in results
            for _, bindings in result["edge_bindings"].items()
            for binding in bindings
        },
    }

    return kgraph, results
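
A hedged usage sketch of the dict-based filter above: curie_mapping maps qnode ids to the curies a result must be bound to, and result_contains_node_bindings (defined elsewhere in this codebase) is assumed to perform that membership check. Identifiers are illustrative.

message = {
    "knowledge_graph": {
        "nodes": {"MONDO:0005148": {"name": "type 2 diabetes mellitus"}},
        "edges": {},
    },
    "results": [
        {
            "node_bindings": {"n0": [{"id": "MONDO:0005148"}]},
            "edge_bindings": {},
        },
    ],
}
kgraph, results = filter_by_curie_mapping(message, {"n0": ["MONDO:0005148"]})
# kgraph now contains only the nodes/edges referenced by the surviving results
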
Example #2
    async def _query(
        self,
        query: Query,
        priority: float = 0,  # lowest goes first
        timeout: Optional[float] = 60.0,
    ) -> dict:
        """Queue up a query for batching and return when completed"""
        if self.worker is None:
            raise RuntimeError(
                "Cannot send a request until a worker is running - enter the context"
            )

        # TODO figure out a way to remove this conversion
        query = Query.parse_obj(query)

        response_queue = asyncio.Queue()

        qgraphs = get_canonical_qgraphs(query.message.query_graph)

        for qgraph in qgraphs:

            subquery = Query(message=Message(query_graph=qgraph))

            # Queue query for processing
            request_id = str(uuid.uuid1())
            await self.request_queue.put(
                (
                    (priority, next(self.counter)),
                    (request_id, subquery, response_queue),
                )
            )

        combined_output = ReasonerResponse.parse_obj(
            {
                "message": {
                    "knowledge_graph": {"nodes": {}, "edges": {}},
                    "results": [],
                }
            }
        )

        for _ in qgraphs:
            # Wait for response
            output: Union[ReasonerResponse, Exception] = await asyncio.wait_for(
                response_queue.get(),
                timeout=None,
            )

            if isinstance(output, Exception):
                raise output

            output.message.query_graph = None
            combined_output.message.update(output.message)

        combined_output.message.query_graph = query.message.query_graph

        return combined_output.dict()
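
The queue items above are keyed by (priority, next(self.counter)). A sketch of why, assuming request_queue is an asyncio.PriorityQueue: the monotonically increasing counter breaks ties between equal priorities, keeps same-priority requests in FIFO order, and prevents the payload tuple from ever being compared.

import asyncio
import itertools

async def demo():
    counter = itertools.count()
    queue = asyncio.PriorityQueue()
    await queue.put(((1.0, next(counter)), "low priority"))
    await queue.put(((0.0, next(counter)), "high priority, queued first"))
    await queue.put(((0.0, next(counter)), "high priority, queued second"))
    order = [(await queue.get())[1] for _ in range(3)]
    assert order == [
        "high priority, queued first",
        "high priority, queued second",
        "low priority",
    ]

asyncio.run(demo())
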
Example #3
def get_curies(message: Message) -> list[str]:
    """Get all node curies used in message.

    Do not examine kedge source and target ids. There ought to be corresponding
    knodes.
    """
    curies = set()
    if message.get("query_graph") is not None:
        for qnode in message["query_graph"]["nodes"].values():
            if qnode_id := qnode.get("ids", False):
                curies |= set(qnode_id)
    return list(curies)
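
A small usage sketch of get_curies (identifiers illustrative): only qnodes with an "ids" field contribute curies.

message = {
    "query_graph": {
        "nodes": {
            "n0": {"ids": ["MONDO:0005148"]},
            "n1": {"categories": ["biolink:ChemicalEntity"]},  # unpinned, contributes nothing
        },
        "edges": {},
    }
}
assert get_curies(message) == ["MONDO:0005148"]
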
Example #4
def apply_curie_map(message: Message, curie_map: dict[str, list[str]]) -> Message:
    """Translate all pinned qnodes to preferred prefix."""
    new_message = dict()
    new_message["query_graph"] = map_qgraph_curies(message["query_graph"],
                                                   curie_map)
    if message.get("knowledge_graph") is not None:
        kgraph = message["knowledge_graph"]
        new_message["knowledge_graph"] = {
            "nodes": {
                curie_map.get(knode_id, [knode_id])[0]: knode
                for knode_id, knode in kgraph["nodes"].items()
            },
            "edges": {
                kedge_id: fix_kedge(kedge, curie_map)
                for kedge_id, kedge in kgraph["edges"].items()
            },
        }
    if message.get("results") is not None:
        results = message["results"]
        new_message["results"] = [
            fix_result(result, curie_map) for result in results
        ]
    return new_message
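
A sketch of the re-keying expression used for knowledge-graph nodes above, which is why each curie_map value is a list of equivalent identifiers with the preferred one first; unmapped ids fall back to themselves (the DOID/MONDO/CHEBI ids are illustrative).

curie_map = {"DOID:9352": ["MONDO:0005148"]}
assert curie_map.get("DOID:9352", ["DOID:9352"])[0] == "MONDO:0005148"
assert curie_map.get("CHEBI:6801", ["CHEBI:6801"])[0] == "CHEBI:6801"  # untouched
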
Example #5
def filter_by_curie_mapping(
    message: Message,
    curie_mapping: dict[str, list[str]],
    kp_id: str = "KP",
) -> Message:
    """
    Filter a message to ensure that all results
    contain the bindings specified in the curie_mapping
    """

    filtered_msg = Message()

    # Only keep results that have node bindings
    # matching the curies given in curie_mapping
    filtered_msg.results = {
        result.copy()
        for result in message.results
        if result_contains_node_bindings(result, curie_mapping)
    }

    # Construct result-specific knowledge graph
    filtered_msg.knowledge_graph = KnowledgeGraph(
        nodes={
            binding.id: message.knowledge_graph.nodes[binding.id]
            for result in filtered_msg.results
            for _, bindings in result.node_bindings.items()
            for binding in bindings
        },
        edges={
            binding.id: message.knowledge_graph.edges[binding.id]
            for result in filtered_msg.results
            for _, bindings in result.edge_bindings.items()
            for binding in bindings
        },
    )

    return filtered_msg
Example #6
async def normalize_message(app: FastAPI, message: Message) -> Message:
    """
    Given a TRAPI message, updates the message to include a
    normalized qgraph, kgraph, and results
    """
    merged_qgraph = await normalize_qgraph(app, message.query_graph)
    merged_kgraph, node_id_map, edge_id_map = await normalize_kgraph(
        app, message.knowledge_graph)
    merged_results = await normalize_results(message.results, node_id_map,
                                             edge_id_map)

    return Message.parse_obj({
        'query_graph': merged_qgraph,
        'knowledge_graph': merged_kgraph,
        'results': merged_results
    })
Example #7
async def query(request: Request, *, jaccard_like: bool = False) -> Message:
    """Score answers.

    This is mostly glue around the heavy lifting in ranker_obj.Ranker
    """
    message = request.message.dict()
    kgraph = message['knowledge_graph']
    answers = message['results']

    # resistance distance ranking
    pr = Ranker(message)
    answers = pr.rank(answers, jaccard_like=jaccard_like)

    # finish
    message['results'] = answers
    return Message(**message)
Example #8
async def query(request: Request, *, max_connectivity: int = -1) -> Message:
    """Fetch answers to question."""
    message = request.message.dict()
    neo4j = Neo4jDatabase(
        url=NEO4J_URL,
        credentials={
            'username': NEO4J_USER,
            'password': NEO4J_PASSWORD,
        },
    )
    qgraph = message['query_graph']
    cypher = get_query(
        qgraph,
        max_connectivity=max_connectivity,
    )
    message = (await neo4j.arun(cypher))[0]
    message['query_graph'] = qgraph
    return Message(**message)
Example #9
async def query(request: Request, *, max_results: int = 3) -> Message:
    """Prescreen subgraphs.

    Keep the top max_results, by their total edge weight.
    """
    message = request.message.dict()
    if max_results < 0:
        return Message(**message)

    kgraph = message['knowledge_graph']
    answers = message['results']

    prescreen_scores = [answer['score'] for answer in answers]

    prescreen_sorting = [
        x[0] for x in heapq.nlargest(max_results,
                                     enumerate(prescreen_scores),
                                     key=operator.itemgetter(1))
    ]

    answers = [answers[i] for i in prescreen_sorting]

    node_ids = set()
    edge_ids = set()
    for answer in answers:
        these_node_ids = [nb['kg_id'] for nb in answer['node_bindings']]
        these_node_ids = flatten_semilist(these_node_ids)
        node_ids.update(these_node_ids)

        these_edge_ids = [eb['kg_id'] for eb in answer['edge_bindings']]
        these_edge_ids = flatten_semilist(these_edge_ids)
        edge_ids.update(these_edge_ids)

    if 'edges' in kgraph:
        kgraph['edges'] = [e for e in kgraph['edges'] if e['id'] in edge_ids]
        kgraph['nodes'] = [n for n in kgraph['nodes'] if n['id'] in node_ids]
        message['knowledge_graph'] = kgraph

    message['results'] = answers
    return Message(**message)
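
How the prescreen indexing works in isolation: heapq.nlargest over enumerate(...) keeps the indices of the top-scoring answers, which are then used to subset answers and, from there, the knowledge graph.

import heapq
import operator

scores = [0.1, 0.9, 0.4, 0.7]
top_indices = [
    x[0]
    for x in heapq.nlargest(3, enumerate(scores), key=operator.itemgetter(1))
]
assert top_indices == [1, 3, 2]  # indices of the three largest scores, best first
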
Example #10
async def query(request: Request) -> Message:
    """Normalize."""
    message = request.message.dict()
    qgraph = message['query_graph']

    qcuries = {
        curie
        for node in qgraph['nodes'] if 'curie' in node
        for curie in ensure_list(node['curie'])
    }
    try:
        knode_ids = {
            node['id']
            for node in message['knowledge_graph']['nodes']
        }
    except (KeyError, TypeError):
        # knowledge graph is absent or malformed
        knode_ids = set()
    curies = {curie for curie in qcuries | knode_ids if curie}
    curie_map = dict(zip(curies, synonymize(*curies)))
    for node in qgraph['nodes']:
        if not node.get('curie', None):
            continue
        node['curie'] = [curie_map[ci] for ci in ensure_list(node['curie'])]
    if not message.get('knowledge_graph') or 'nodes' not in message['knowledge_graph']:
        return Message(**message)
    for node in message['knowledge_graph']['nodes']:
        node['id'] = curie_map[node['id']]
    for edge in message['knowledge_graph']['edges']:
        edge['source_id'] = curie_map[edge['source_id']]
        edge['target_id'] = curie_map[edge['target_id']]
    if 'results' not in message:
        return Message(**message)
    for result in message['results']:
        for binding in result['node_bindings']:
            binding['kg_id'] = curie_map[binding['kg_id']]

    return Message(**message)
Example #11
async def query(request: Request) -> Message:
    """Minify message.

    for knowledge graph:
      * keep only node properties: id, name, type
      * keep only edge properties: id, source_id, target_id, type
    for results:
      * keep only qg_id, kg_id
    """
    message = request.message.dict()
    kgraph = message['knowledge_graph']
    results = message['results']

    kgraph['nodes'] = [{
        'id': node['id'],
        'name': node['name'],
        'type': node['type']
    } for node in kgraph['nodes']]
    kgraph['edges'] = [{
        'id': edge['id'],
        'source_id': edge['source_id'],
        'target_id': edge['target_id'],
        'type': edge['type']
    } for edge in kgraph['edges']]
    for result in results:
        result['node_bindings'] = [{
            'qg_id': nb['qg_id'],
            'kg_id': nb['kg_id']
        } for nb in result['node_bindings']]
        result['edge_bindings'] = [{
            'qg_id': eb['qg_id'],
            'kg_id': eb['kg_id']
        } for eb in result['edge_bindings']]

    message['knowledge_graph'] = kgraph
    message['results'] = results
    return Message(**message)
Example #12
async def run_workflow(
    request: ReasonerQuery = Body(..., example=load_example("query")),
) -> Response:
    """Run workflow."""
    request_dict = request.dict(exclude_unset=True)

    message = request_dict["message"]
    workflow = request_dict["workflow"]
    logger = gen_logger()
    qgraph = message["query_graph"]
    kgraph = {"nodes": {}, "edges": {}}
    if "knowledge_graph" in message.keys():
        if "nodes" in message["knowledge_graph"].keys():
            kgraph = message["knowledge_graph"]

    async with httpx.AsyncClient(verify=False, timeout=60.0) as client:
        for operation in workflow:
            service_operation_responses = []
            for service in SERVICES[operation["id"]]:
                url = service["url"]
                service_name = service["title"]
                logger.debug(
                    f"Requesting operation '{operation}' from {service_name}..."
                )
                try:
                    response = await post_safely(
                        url,
                        {
                            "message": message,
                            "workflow": [
                                operation,
                            ],
                            "submitter": "Workflow Runner",
                        },
                        client=client,
                        timeout=60.0,
                        logger=logger,
                        service_name=service_name,
                    )
                    logger.debug(
                        f"Received operation '{operation}' from {service_name}..."
                    )

                    try:
                        response = await post_safely(
                            NORMALIZER_URL + "/response", {
                                "message": response["message"],
                                "submitter": "Workflow Runner"
                            },
                            client=client,
                            timeout=60.0,
                            logger=logger,
                            service_name="node_normalizer")
                    except RuntimeError as e:
                        logger.warning({"error": str(e)})

                    service_operation_responses.append(response)

                except RuntimeError as e:
                    logger.warning({"error": str(e)})

                if not OPERATIONS[operation["id"]]["unique"] and len(
                        service_operation_responses) == 1:
                    # We only need one successful response for non-unique operations
                    break

            logger.debug(
                f"Merging {len(service_operation_responses)} responses for '{operation}'..."
            )
            m = Message(
                query_graph=QueryGraph.parse_obj(qgraph),
                knowledge_graph=KnowledgeGraph.parse_obj(kgraph),
            )
            for response in service_operation_responses:
                response["message"]["query_graph"] = qgraph
                m.update(Message.parse_obj(response["message"]))
            message = m.dict()

    return Response(
        message=message,
        workflow=workflow,
        logs=logger.handlers[0].store,
    )
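
For reference, a minimal request body of the kind run_workflow consumes. The operation ids shown are examples from the Translator workflow vocabulary and the curies are illustrative; the exact set accepted depends on the SERVICES/OPERATIONS configuration.

example_request = {
    "message": {
        "query_graph": {
            "nodes": {
                "n0": {"ids": ["MONDO:0005148"]},
                "n1": {"categories": ["biolink:ChemicalEntity"]},
            },
            "edges": {
                "e01": {
                    "subject": "n1",
                    "object": "n0",
                    "predicates": ["biolink:treats"],
                },
            },
        },
    },
    "workflow": [
        {"id": "lookup"},
        {"id": "filter_results_top_n", "parameters": {"max_results": 50}},
    ],
}
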
Example #13
    async def setup(
        self,
        qgraph: dict,
    ):
        """Set up."""

        # Update qgraph identifiers
        message = Message.parse_obj({"query_graph": qgraph})
        curies = get_curies(message)
        if len(curies):
            await self.synonymizer.load_curies(*curies)
            curie_map = self.synonymizer.map(curies, self.preferred_prefixes)
            map_qgraph_curies(message.query_graph, curie_map, primary=True)

        self.qgraph = message.query_graph.dict()

        # Fill in missing categories and predicates using normalizer
        await fill_categories_predicates(self.qgraph, self.logger)

        # Initialize registry
        registry = Registry(settings.kpregistry_url, self.logger)

        # Generate traversal plan
        self.plan, kps = await generate_plan(
            self.qgraph,
            kp_registry=registry,
            logger=self.logger,
        )

        # extract KP preferred prefixes from plan
        self.kp_preferred_prefixes = dict()
        self.portal.tservers = dict()
        for kp_id, kp in kps.items():
            url = kp["url"][:-5] + "meta_knowledge_graph"
            try:
                async with httpx.AsyncClient() as client:
                    response = await client.get(
                        url,
                        timeout=10,
                    )
                response.raise_for_status()
                meta_kg = response.json()
                MetaKnowledgeGraph.parse_obj(meta_kg)
                self.kp_preferred_prefixes[kp_id] = {
                    category: data["id_prefixes"]
                    for category, data in meta_kg["nodes"].items()
                }
            except (httpx.ConnectError, httpx.ConnectTimeout,
                    httpx.ReadTimeout) as err:
                self.logger.warning(
                    f"Unable to get meta knowledge graph from KP {kp_id}: {err}"
                )
                self.kp_preferred_prefixes[kp_id] = dict()
            except httpx.HTTPStatusError as err:
                self.logger.warning(
                    f"Received error response from /meta_knowledge_graph for KP {kp_id}: {err.response.text}"
                )
                self.kp_preferred_prefixes[kp_id] = dict()
            except JSONDecodeError as err:
                self.logger.warning(
                    f"Unable to parse meta knowledge graph from KP {kp_id}: {err}"
                )
                self.kp_preferred_prefixes[kp_id] = dict()
            except pydantic.ValidationError as err:
                self.logger.warning(
                    f"Meta knowledge graph from KP {kp_id} is non-compliant: {err}"
                )
                self.kp_preferred_prefixes[kp_id] = dict()
            except Exception as err:
                self.logger.warning(
                    f"Something went wrong while parsing meta knowledge graph from KP {kp_id}: {err}"
                )
                self.kp_preferred_prefixes[kp_id] = dict()

            self.portal.tservers[kp_id] = ThrottledServer(
                kp_id,
                url=kp["url"],
                request_qty=1,
                request_duration=1,
                preproc=self.get_processor(self.kp_preferred_prefixes[kp_id]),
                postproc=self.get_processor(self.preferred_prefixes),
                logger=self.logger,
            )
        self.kps = {
            kp_id: KnowledgeProvider(
                details,
                self.portal,
                kp_id,
            )
            for kp_id, details in kps.items()
        }
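
The only part of each KP's /meta_knowledge_graph response used above is the per-category id_prefixes list. A sketch of that fragment and the resulting mapping (categories and prefixes are illustrative):

example_meta_kg = {
    "nodes": {
        "biolink:Disease": {"id_prefixes": ["MONDO", "DOID"]},
        "biolink:Gene": {"id_prefixes": ["NCBIGene", "HGNC"]},
    },
    "edges": [],
}
kp_preferred_prefixes = {
    category: data["id_prefixes"]
    for category, data in example_meta_kg["nodes"].items()
}
# {"biolink:Disease": ["MONDO", "DOID"], "biolink:Gene": ["NCBIGene", "HGNC"]}
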
Example #14
async def query(request: Request) -> Message:
    """Add support to message.

    Add support edges to knowledge_graph and bindings to results.
    """
    message = request.message.dict()

    kgraph = message['knowledge_graph']
    qgraph = message['query_graph']
    answers = message['results']

    # get cache if possible
    try:
        cache = Cache(
            redis_host=CACHE_HOST,
            redis_port=CACHE_PORT,
            redis_db=CACHE_DB,
            redis_password=CACHE_PASSWORD,
        )
    except Exception as err:
        logger.exception(err)
        cache = None

    redis_batch_size = 100

    async with OmnicorpSupport() as supporter:
        # get all node supports

        keys = [
            f"{supporter.__class__.__name__}({node['id']})"
            for node in kgraph['nodes']
        ]
        values = []
        for batch in batches(keys, redis_batch_size):
            # cache may be None if Redis was unreachable above
            values.extend(cache.mget(*batch) if cache else [None] * len(batch))

        jobs = [
            count_node_pmids(supporter, node, key, value, cache)
            for node, value, key in zip(kgraph['nodes'], values, keys)
        ]

        # Generate a set of pairs of node curies
        pair_to_answer = defaultdict(set)  # a map of node pairs to answers
        for ans_idx, answer_map in enumerate(answers):

            # Get all nodes that are not part of sets and densely connect them
            nodes = sorted([
                nb['kg_id'] for nb in answer_map['node_bindings']
                if isinstance(nb['kg_id'], str)
            ])
            for node_pair in combinations(nodes, 2):
                pair_to_answer[node_pair].add(ans_idx)

            # For all nodes that are within sets, connect them to all nodes that are not in sets
            set_nodes_list_list = [
                nb['kg_id'] for nb in answer_map['node_bindings']
                if isinstance(nb['kg_id'], list)
            ]
            set_nodes = [n for el in set_nodes_list_list for n in el]
            for set_node in set_nodes:
                for node in nodes:
                    node_pair = tuple(sorted((node, set_node)))
                    pair_to_answer[node_pair].add(ans_idx)

        # get all pair supports
        cached_prefixes = cache.get('OmnicorpPrefixes') if cache else None

        keys = [
            f"{supporter.__class__.__name__}_count({pair[0]},{pair[1]})"
            for pair in pair_to_answer
        ]
        values = []
        for batch in batches(keys, redis_batch_size):
            # cache may be None if Redis was unreachable above
            values.extend(cache.mget(*batch) if cache else [None] * len(batch))

        jobs.extend([
            count_shared_pmids(
                supporter,
                support_idx,
                pair,
                key,
                value,
                cache,
                cached_prefixes,
                kgraph,
                pair_to_answer,
                answers,
            ) for support_idx, (
                pair, value,
                key) in enumerate(zip(pair_to_answer, values, keys))
        ])
        await asyncio.gather(*jobs)

    message['knowledge_graph'] = kgraph
    message['results'] = answers
    return Message(**message)
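
The pair generation above, in miniature: non-set bound nodes of a result are densely connected with itertools.combinations, so each unordered pair gets one omnicorp co-occurrence lookup (curies illustrative).

from itertools import combinations

nodes = sorted(["MONDO:0005148", "CHEBI:6801", "NCBIGene:3643"])
pairs = list(combinations(nodes, 2))
assert pairs == [
    ("CHEBI:6801", "MONDO:0005148"),
    ("CHEBI:6801", "NCBIGene:3643"),
    ("MONDO:0005148", "NCBIGene:3643"),
]
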
Example #15
async def query(request: Request, *, exclude_sets=False) -> Message:
    """Compute informativeness weights for edges."""
    message = request.message.dict()
    qgraph = message['query_graph']
    results = message['results']

    qnodes = qgraph['nodes']
    qedges = qgraph['edges']

    # knode_map = {knode['id']: knode for knode in knodes}
    qnode_map = {qnode['id']: qnode for qnode in qnodes}
    qedge_map = {qedge['id']: qedge for qedge in qedges}

    driver = Neo4jDatabase(
        url=NEO4J_URL,
        credentials={
            'username': NEO4J_USER,
            'password': NEO4J_PASSWORD,
        },
    )
    redges_by_id = dict()
    count_plans = defaultdict(lambda: defaultdict(list))
    for kdx, result in enumerate(results):
        rgraph = get_rgraph(result, message)
        redges_by_id.update({(kdx, redge['id']): redge
                             for redge in rgraph['edges']})

        for redge in rgraph['edges']:
            if (not exclude_sets) or qnode_map[redge['qg_target_id']].get(
                    'set', False):
                count_plans[redge['kg_source_id']][(
                    redge['eb']['qg_id'], redge['qg_target_id'])].append(
                        (kdx, redge['id']))
            if (not exclude_sets) or qnode_map[redge['qg_source_id']].get(
                    'set', False):
                count_plans[redge['kg_target_id']][(
                    redge['eb']['qg_id'], redge['qg_source_id'])].append(
                        (kdx, redge['id']))

    count_to_redge = {}
    for ldx, batch in enumerate(batches(list(count_plans.keys()), 1000)):
        batch_bits = []
        for idx, ksource_id in enumerate(batch):
            sets = []
            plan = count_plans[ksource_id]
            anchor_node_reference = NodeReference({
                'id': f'n{idx:04d}',
                'curie': ksource_id,
                'type': 'named_thing'
            })
            anchor_node_reference = str(anchor_node_reference)
            base = f"MATCH ({anchor_node_reference}) "
            for jdx, (qlink, redge_ids) in enumerate(plan.items()):
                cypher_counts = []
                qedge_id, qtarget_id = qlink
                count_id = f"c{idx:03d}{chr(97 + jdx)}"
                qedge = qedge_map[qedge_id]
                edge_reference = EdgeReference(qedge, anonymous=True)
                anon_node_reference = NodeReference({
                    **qnode_map[qtarget_id],
                    'id':
                    count_id,
                })
                if qedge['source_id'] == qtarget_id:
                    source_reference = anon_node_reference
                    target_reference = anchor_node_reference
                elif qedge['target_id'] == qtarget_id:
                    source_reference = anchor_node_reference
                    target_reference = anon_node_reference
                cypher_counts.append(
                    f"{anon_node_reference.name}: count(DISTINCT {anon_node_reference.name})"
                )
                count_to_redge[count_id] = redge_ids
                sets.append(
                    f'MATCH ({source_reference}){edge_reference}({target_reference})'
                    + ' RETURN {' + ', '.join(cypher_counts) + '} as output')
            batch_bits.append(' UNION ALL '.join(sets))
        cypher = ' UNION ALL '.join(batch_bits)
        response = driver.run(cypher)

        degrees = {
            key: value
            for result in response for key, value in result['output'].items()
        }

        for key in degrees:
            for redge_id in count_to_redge[key]:
                eb = redges_by_id[redge_id]['eb']
                eb['weight'] = eb.get('weight', 1.0) / degrees[key]

    message['results'] = results
    return Message(**message)
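
The effect of the degree counts computed above, in miniature: each contributing edge binding's weight is divided by the number of distinct neighbours found for the anchored node, so edges through hub nodes carry less informativeness (toy degrees below, not real query results).

edge_bindings = {"via_hub": {"weight": 1.0}, "via_rare": {"weight": 1.0}}
degrees = {"via_hub": 500, "via_rare": 3}
for key, eb in edge_bindings.items():
    eb["weight"] = eb.get("weight", 1.0) / degrees[key]
# 1.0 / 500 == 0.002 and 1.0 / 3 ≈ 0.333
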
Example #16
    async def process_batch(
        self,
    ):
        """Set up a subscriber to process batching."""
        # Initialize the TAT
        #
        # TAT = Theoretical Arrival Time
        # When the next request should be sent
        # to adhere to the rate limit.
        #
        # This is an implementation of the GCRA algorithm
        # More information can be found here:
        # https://dev.to/astagi/rate-limiting-using-python-and-redis-58gk
        if self.request_qty > 0:
            interval = self.request_duration / self.request_qty
            tat = datetime.datetime.utcnow() + interval

        while True:
            # Get everything in the stream or wait for something to show up
            priority, (
                request_id,
                payload,
                response_queue,
            ) = await self.request_queue.get()
            priorities = {request_id: priority}
            request_value_mapping = {request_id: payload}
            response_queues = {request_id: response_queue}
            while True:
                if (
                    self.max_batch_size is not None
                    and len(request_value_mapping) == self.max_batch_size
                ):
                    break
                try:
                    priority, (
                        request_id,
                        payload,
                        response_queue,
                    ) = self.request_queue.get_nowait()
                except QueueEmpty:
                    break
                priorities[request_id] = priority
                request_value_mapping[request_id] = payload
                response_queues[request_id] = response_queue

            # Extract a curie mapping from each request
            request_curie_mapping = {
                request_id: get_curies(request_value.message.query_graph)
                for request_id, request_value in request_value_mapping.items()
            }

            # Find requests that are the same (those that we can merge)
            # This disregards non-matching IDs because the IDs have been
            # removed with the extract_curie method
            stripped_qgraphs = {
                request_id: remove_curies(request.message.query_graph)
                for request_id, request in request_value_mapping.items()
            }
            first_value = next(iter(stripped_qgraphs.values()))

            batch_request_ids = get_keys_with_value(
                stripped_qgraphs,
                first_value,
            )

            # Re-queue the un-selected requests
            for request_id in request_value_mapping:
                if request_id not in batch_request_ids:
                    await self.request_queue.put(
                        (
                            priorities[request_id],
                            (
                                request_id,
                                request_value_mapping[request_id],
                                response_queues[request_id],
                            ),
                        )
                    )

            request_value_mapping = {
                k: v for k, v in request_value_mapping.items() if k in batch_request_ids
            }

            # Filter curie mapping to only include matching requests
            request_curie_mapping = {
                k: v for k, v in request_curie_mapping.items() if k in batch_request_ids
            }

            # Pull first value from request_value_mapping
            # to use as a template for our merged request
            merged_request_value = copy.deepcopy(
                next(iter(request_value_mapping.values()))
            )

            # Remove qnode ids
            for qnode in merged_request_value.message.query_graph.nodes.values():
                qnode.ids = None

            # Update merged request using curie mapping
            for curie_mapping in request_curie_mapping.values():
                for node_id, node_curies in curie_mapping.items():
                    node = merged_request_value.message.query_graph.nodes[node_id]
                    if node.ids is None:
                        node.ids = node_curies.copy()
                    else:
                        node.ids.extend(node_curies)
            # TODO replace qnode.ids with a HashableSet so that this can be safely removed
            for qnode in merged_request_value.message.query_graph.nodes.values():
                if qnode.ids:
                    qnode.ids = list(set(qnode.ids))

            response_values = dict()
            try:
                # Make request
                self.logger.info(
                    "[{id}] Sending request made of {subrequests} subrequests ({curies} curies)".format(
                        id=self.id,
                        subrequests=len(request_curie_mapping),
                        curies=" x ".join(
                            str(len(qnode.ids or []))
                            for qnode in merged_request_value.message.query_graph.nodes.values()
                        ),
                    )
                )
                self.logger.context = self.id
                await self.preproc(merged_request_value, self.logger)
                # TODO rewrite this whole function to use pydantic model
                merged_request_value = merged_request_value.dict()
                merged_request_value["submitter"] = "infores:aragorn"
                merged_request_value = remove_null_values(merged_request_value)
                async with httpx.AsyncClient() as client:
                    response = await client.post(
                        self.url,
                        json=merged_request_value,
                        timeout=self.timeout,
                    )
                if response.status_code == 429:
                    # reset TAT
                    interval = self.request_duration / self.request_qty
                    tat = datetime.datetime.utcnow() + interval
                    # re-queue requests
                    for request_id in request_value_mapping:
                        await self.request_queue.put(
                            (
                                priorities[request_id],
                                (
                                    request_id,
                                    request_value_mapping[request_id],
                                    response_queues[request_id],
                                ),
                            )
                        )
                    # try again later
                    continue

                response.raise_for_status()

                # Parse with reasoner_pydantic to validate
                response_body = ReasonerResponse.parse_obj(response.json())
                await self.postproc(response_body, self.logger)
                message = response_body.message
                results = message.results or []
                self.logger.info(
                    "[{}] Received response with {} results in {} seconds".format(
                        self.id,
                        len(results),
                        response.elapsed.total_seconds(),
                    )
                )

                try:
                    if len(request_curie_mapping) == 1:
                        request_id = next(iter(request_curie_mapping))
                        # Make a copy
                        response_values[request_id] = ReasonerResponse(
                            message=Message()
                        )
                        response_values[
                            request_id
                        ].message.query_graph = request_value_mapping[
                            request_id
                        ].message.query_graph.copy()
                        response_values[request_id].message.knowledge_graph = (
                            message.knowledge_graph
                            or KnowledgeGraph(nodes={}, edges={})
                        ).copy()
                        response_values[request_id].message.results = (
                            message.results or HashableSet(__root__=[])
                        ).copy()
                    else:
                        # Split using the request_curie_mapping
                        for request_id, curie_mapping in request_curie_mapping.items():
                            filtered_msg = filter_by_curie_mapping(
                                message, curie_mapping, kp_id=self.id
                            )
                            filtered_msg.query_graph = request_value_mapping[
                                request_id
                            ].message.query_graph.copy()
                            response_values[request_id] = ReasonerResponse(
                                message=filtered_msg
                            )
                except Exception as err:
                    # Raise more descriptive error message of response message parsing
                    raise Exception(
                        "[{}] Failed to parse message response: {} with Error: {}".format(
                            self.id,
                            response.json(),
                            err,
                        )
                    )
            except (
                asyncio.exceptions.TimeoutError,
                httpx.RequestError,
                httpx.HTTPStatusError,
                JSONDecodeError,
                pydantic.ValidationError,
                Exception,
            ) as e:
                for request_id, curie_mapping in request_curie_mapping.items():
                    response_values[request_id] = ReasonerResponse(
                        message=request_value_mapping[request_id].message.copy()
                    )
                if isinstance(e, asyncio.TimeoutError):
                    self.logger.warning(
                        {
                            "message": f"{self.id} took >60 seconds to respond",
                            "error": str(e),
                            "request": elide_curies(merged_request_value),
                        }
                    )
                elif isinstance(e, httpx.ReadTimeout):
                    self.logger.warning(
                        {
                            "message": f"{self.id} took >60 seconds to respond",
                            "error": str(e),
                            "request": log_request(e.request),
                        }
                    )
                elif isinstance(e, httpx.RequestError):
                    # Log error
                    self.logger.warning(
                        {
                            "message": f"Request Error contacting {self.id}",
                            "error": str(e),
                            "request": log_request(e.request),
                        }
                    )
                elif isinstance(e, httpx.HTTPStatusError):
                    # Log error with response
                    self.logger.warning(
                        {
                            "message": f"Response Error contacting {self.id}",
                            "error": str(e),
                            "request": log_request(e.request),
                            "response": log_response(e.response),
                        }
                    )
                elif isinstance(e, JSONDecodeError):
                    # JSONDecodeError carries no request/response attributes, so
                    # log the request we sent and the raw response body instead
                    self.logger.warning(
                        {
                            "message": f"Received bad JSON data from {self.id}",
                            "request": elide_curies(merged_request_value),
                            "response": response.text,
                            "error": str(e),
                        }
                    )
                elif isinstance(e, pydantic.ValidationError):
                    self.logger.warning(
                        {
                            "message": f"Received non-TRAPI compliant response from {self.id}",
                            "error": str(e),
                        }
                    )
                else:
                    self.logger.warning(
                        {
                            "message": f"Something went wrong while querying {self.id}",
                            "error": str(e),
                        }
                    )

            for request_id, response_value in response_values.items():
                # Write finished value to DB
                await response_queues[request_id].put(response_value)

            # if request_qty == 0 we don't enforce the rate limit
            if self.request_qty > 0:
                time_remaining_seconds = (
                    tat - datetime.datetime.utcnow()
                ).total_seconds()

                # Wait for TAT
                if time_remaining_seconds > 0:
                    await asyncio.sleep(time_remaining_seconds)

                # Update TAT
                tat = datetime.datetime.utcnow() + interval
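
A stripped-down sketch of the TAT-based rate limiting (GCRA) used in process_batch, with the batching and error handling removed; the names here are hypothetical, only the interval/TAT arithmetic mirrors the code above.

import asyncio
import datetime

async def send_rate_limited(items, do_send, request_qty=1,
                            request_duration=datetime.timedelta(seconds=1)):
    """Dispatch at most request_qty items per request_duration."""
    interval = request_duration / request_qty
    tat = datetime.datetime.utcnow() + interval  # theoretical arrival time of the next send
    for item in items:
        await do_send(item)
        remaining = (tat - datetime.datetime.utcnow()).total_seconds()
        if remaining > 0:
            await asyncio.sleep(remaining)  # wait until the TAT before sending again
        tat = datetime.datetime.utcnow() + interval
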
Example #17
async def query(
    request: Request,
    relevance: Optional[float] = Query(
        0.0025,
        description='portion of cooccurrence pubs relevant to question',
    ),
    wt_min: Optional[float] = Query(
        0.0,
        description='minimum weight (at 0 pubs)',
    ),
    wt_max: Optional[float] = Query(
        1.0,
        description='maximum weight (at inf pubs)',
    ),
    p50: Optional[float] = Query(
        2.0,
        description='pubs at 50% of wt_max',
    ),
) -> Message:
    """Weight kgraph edges based on metadata.

    "19 pubs from CTD is a 1, and 2 should at least be 0.5"
        - cbizon
    """
    message = request.message.dict()

    def sigmoid(x):
        """Scale with partial sigmoid - the right (concave down) half.

        Such that:
           f(0) = wt_min
           f(inf) = wt_max
           f(p50) = 0.5 * wt_max
        """
        a = 2 * (wt_max - wt_min)
        r = 0.5 * wt_max
        c = wt_max - 2 * wt_min
        k = 1 / p50 * (math.log(r + c) - math.log(a - r - c))
        return a / (1 + math.exp(-k * x)) - c

    kgraph = message['knowledge_graph']
    node_pubs = {
        n['id']: n.get('omnicorp_article_count', None)
        for n in kgraph['nodes']
    }
    all_pubs = 27840000

    results = message['results']

    # ensure that each edge_binding has a single kg_id
    for result in results:
        result['edge_bindings'] = [
            eb
            for ebs in result['edge_bindings']
            for eb in (
                [{'qg_id': ebs['qg_id'], 'kg_id': kg_id} for kg_id in ebs['kg_id']]
                if isinstance(ebs['kg_id'], list)
                else [ebs]
            )
        ]

    # map kedges to edge_bindings
    krmap = defaultdict(list)
    for result in results:
        for eb in result['edge_bindings']:
            assert isinstance(eb['kg_id'], str)
            eb['weight'] = eb.get('weight', 1.0)
            krmap[eb['kg_id']].append(eb)

    edges = kgraph['edges']
    for edge in edges:
        edge_pubs = edge.get('num_publications',
                             len(edge.get('publications', [])))
        if edge['type'] == 'literature_co-occurrence':
            source_pubs = int(node_pubs[edge['source_id']])
            target_pubs = int(node_pubs[edge['target_id']])

            cov = (edge_pubs / all_pubs
                   ) - (source_pubs / all_pubs) * (target_pubs / all_pubs)
            cov = max((cov, 0.0))
            effective_pubs = cov * all_pubs * relevance
        else:
            effective_pubs = edge_pubs + 1  # consider the curation a pub

        for redge in krmap[edge['id']]:
            redge['weight'] = redge.get('weight',
                                        1.0) * sigmoid(effective_pubs)

    message['knowledge_graph'] = kgraph
    return Message(**message)
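
A quick check that the sigmoid constants above satisfy the documented constraints at the default parameters (wt_min=0, wt_max=1, p50=2):

import math

def check_sigmoid(x, wt_min=0.0, wt_max=1.0, p50=2.0):
    a = 2 * (wt_max - wt_min)
    r = 0.5 * wt_max
    c = wt_max - 2 * wt_min
    k = 1 / p50 * (math.log(r + c) - math.log(a - r - c))
    return a / (1 + math.exp(-k * x)) - c

assert abs(check_sigmoid(0.0) - 0.0) < 1e-9   # f(0) == wt_min
assert abs(check_sigmoid(2.0) - 0.5) < 1e-9   # f(p50) == 0.5 * wt_max
assert abs(check_sigmoid(1e6) - 1.0) < 1e-6   # f(inf) -> wt_max
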
Example #18
]

table = "\n\n"

table += "          Benchmark Name            |  Output Size (MB)  |  Total Time (s)\n"
table += "--------------------------------------------------------------------------\n"

for b in benchmarks:
    input_messages = [
        generate_message_parameterized(**b["params"])
        for _ in range(b["msg_count"])
    ]

    start = time.time()

    combined_msg = Message(results=[])

    print(f"Running benchmark {b['name']}")
    for m in tqdm(input_messages):
        combined_msg.update(m)

    end = time.time()

    # Compute file size
    print("\nComputing final message size, this may take a while...")
    output_file_size = len(combined_msg.json())

    table += f"  {b['name'].center(32)}  |  {output_file_size/1e6:16}  |  {end - start:14.2f}\n"

print(table)
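
The benchmarks list is defined above this excerpt and not shown; judging from the keys used, each entry presumably looks something like the following (field values hypothetical, and the "params" kwargs are whatever generate_message_parameterized accepts):

example_benchmark = {
    "name": "small messages",   # label printed in the table
    "msg_count": 10,            # how many generated messages get merged
    "params": {},               # kwargs forwarded to generate_message_parameterized
}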