def filter_by_curie_mapping(
    message: Message,
    curie_mapping: dict[str, list[str]],
    kp_id: str = "KP",
) -> tuple[dict, list]:
    """
    Filter a message to ensure that all results contain
    the bindings specified in the curie_mapping
    """

    # Only keep results where there is a node binding
    # that connects to our given kgraph_node_id
    results = [
        result
        for result in (message.get("results") or [])
        if result_contains_node_bindings(result, curie_mapping)
    ]

    # Construct result-specific knowledge graph
    kgraph = {
        "nodes": {
            binding["id"]: message["knowledge_graph"]["nodes"][binding["id"]]
            for result in results
            for _, bindings in result["node_bindings"].items()
            for binding in bindings
        },
        "edges": {
            binding["id"]: message["knowledge_graph"]["edges"][binding["id"]]
            for result in results
            for _, bindings in result["edge_bindings"].items()
            for binding in bindings
        },
    }

    return kgraph, results
async def _query(
    self,
    query: Query,
    priority: float = 0,  # lowest goes first
    timeout: Optional[float] = 60.0,
) -> ReasonerResponse:
    """Queue up a query for batching and return when completed"""
    if self.worker is None:
        raise RuntimeError(
            "Cannot send a request until a worker is running - enter the context"
        )
    # TODO figure out a way to remove this conversion
    query = Query.parse_obj(query)

    response_queue = asyncio.Queue()

    qgraphs = get_canonical_qgraphs(query.message.query_graph)

    for qgraph in qgraphs:
        subquery = Query(message=Message(query_graph=qgraph))

        # Queue query for processing
        request_id = str(uuid.uuid1())
        await self.request_queue.put(
            (
                (priority, next(self.counter)),
                (request_id, subquery, response_queue),
            )
        )

    combined_output = ReasonerResponse.parse_obj(
        {
            "message": {
                "knowledge_graph": {"nodes": {}, "edges": {}},
                "results": [],
            }
        }
    )

    for _ in qgraphs:
        # Wait for response
        output: Union[ReasonerResponse, Exception] = await asyncio.wait_for(
            response_queue.get(),
            timeout=None,
        )

        if isinstance(output, Exception):
            raise output

        output.message.query_graph = None
        combined_output.message.update(output.message)

    combined_output.message.query_graph = query.message.query_graph

    return combined_output.dict()
def get_curies(message: Message) -> list[str]:
    """Get all node curies used in message.

    Do not examine kedge source and target ids. There ought to be
    corresponding knodes.
    """
    curies = set()
    if message.get("query_graph") is not None:
        for qnode in message["query_graph"]["nodes"].values():
            if qnode_id := qnode.get("ids", False):
                curies |= set(qnode_id)
    if message.get("knowledge_graph") is not None:
        # knode keys are the curies; kedge source/target ids are skipped
        curies |= set(message["knowledge_graph"]["nodes"].keys())
    return list(curies)
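A minimal usage sketch for get_curies, assuming the dict-based TRAPI message shape used above; the message contents are hypothetical and only illustrate that pinned qnode ids are collected:

# Hypothetical minimal TRAPI message; only qnode "ids" are pinned.
example_message = {
    "query_graph": {
        "nodes": {
            "n0": {"ids": ["MONDO:0005148"]},
            "n1": {"categories": ["biolink:ChemicalEntity"]},
        },
        "edges": {},
    },
}
assert set(get_curies(example_message)) == {"MONDO:0005148"}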
def apply_curie_map(message: Message, curie_map: dict[str, list[str]]) -> Message:
    """Translate all pinned qnodes to preferred prefix."""
    new_message = dict()
    new_message["query_graph"] = map_qgraph_curies(message["query_graph"], curie_map)
    if message.get("knowledge_graph") is not None:
        kgraph = message["knowledge_graph"]
        new_message["knowledge_graph"] = {
            "nodes": {
                curie_map.get(knode_id, [knode_id])[0]: knode
                for knode_id, knode in kgraph["nodes"].items()
            },
            "edges": {
                kedge_id: fix_kedge(kedge, curie_map)
                for kedge_id, kedge in kgraph["edges"].items()
            },
        }
    if message.get("results") is not None:
        results = message["results"]
        new_message["results"] = [
            fix_result(result, curie_map)
            for result in results
        ]
    return new_message
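For reference, curie_map here maps each known curie to a list of equivalent identifiers with the preferred one first; a toy illustration of the knode re-keying comprehension in isolation (identifiers and equivalence list are hypothetical):

curie_map = {
    "DOID:9352": ["MONDO:0005148", "DOID:9352"],  # hypothetical equivalence list
}
knodes = {
    "DOID:9352": {"name": "type 2 diabetes"},
    "CHEBI:6801": {"name": "metformin"},  # not in the map, so its key is kept
}
renamed = {
    curie_map.get(knode_id, [knode_id])[0]: knode
    for knode_id, knode in knodes.items()
}
assert set(renamed) == {"MONDO:0005148", "CHEBI:6801"}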
def filter_by_curie_mapping(
    message: Message,
    curie_mapping: dict[str, list[str]],
    kp_id: str = "KP",
):
    """
    Filter a message to ensure that all results contain
    the bindings specified in the curie_mapping
    """

    filtered_msg = Message()

    # Only keep results where there is a node binding
    # that connects to our given kgraph_node_id
    filtered_msg.results = {
        result.copy()
        for result in message.results
        if result_contains_node_bindings(result, curie_mapping)
    }

    # Construct result-specific knowledge graph
    filtered_msg.knowledge_graph = KnowledgeGraph(
        nodes={
            binding.id: message.knowledge_graph.nodes[binding.id]
            for result in filtered_msg.results
            for _, bindings in result.node_bindings.items()
            for binding in bindings
        },
        edges={
            binding.id: message.knowledge_graph.edges[binding.id]
            for result in filtered_msg.results
            for _, bindings in result.edge_bindings.items()
            for binding in bindings
        },
    )

    return filtered_msg
async def normalize_message(app: FastAPI, message: Message) -> Message:
    """
    Given a TRAPI message, updates the message to include a
    normalized qgraph, kgraph, and results
    """
    merged_qgraph = await normalize_qgraph(app, message.query_graph)
    merged_kgraph, node_id_map, edge_id_map = await normalize_kgraph(
        app, message.knowledge_graph
    )
    merged_results = await normalize_results(message.results, node_id_map, edge_id_map)

    return Message.parse_obj({
        'query_graph': merged_qgraph,
        'knowledge_graph': merged_kgraph,
        'results': merged_results,
    })
async def query(request: Request, *, jaccard_like: bool = False) -> Message:
    """Score answers.

    This is mostly glue around the heavy lifting in ranker_obj.Ranker
    """
    message = request.message.dict()
    kgraph = message['knowledge_graph']
    answers = message['results']

    # resistance distance ranking
    pr = Ranker(message)
    answers = pr.rank(answers, jaccard_like=jaccard_like)

    # finish
    message['results'] = answers
    return Message(**message)
async def query(request: Request, *, max_connectivity: int = -1) -> Message:
    """Fetch answers to question."""
    message = request.message.dict()

    neo4j = Neo4jDatabase(
        url=NEO4J_URL,
        credentials={
            'username': NEO4J_USER,
            'password': NEO4J_PASSWORD,
        },
    )

    qgraph = message['query_graph']
    cypher = get_query(
        qgraph,
        max_connectivity=max_connectivity,
    )
    message = (await neo4j.arun(cypher))[0]
    message['query_graph'] = qgraph
    return Message(**message)
async def query(request: Request, *, max_results: int = 3) -> Message:
    """Prescreen subgraphs.

    Keep the top max_results, by their total edge weight.
    """
    message = request.message.dict()
    if max_results < 0:
        return Message(**message)

    kgraph = message['knowledge_graph']
    answers = message['results']

    prescreen_scores = [answer['score'] for answer in answers]

    prescreen_sorting = [
        x[0]
        for x in heapq.nlargest(
            max_results,
            enumerate(prescreen_scores),
            key=operator.itemgetter(1),
        )
    ]

    answers = [answers[i] for i in prescreen_sorting]

    node_ids = set()
    edge_ids = set()
    for answer in answers:
        these_node_ids = [nb['kg_id'] for nb in answer['node_bindings']]
        these_node_ids = flatten_semilist(these_node_ids)
        node_ids.update(these_node_ids)

        these_edge_ids = [eb['kg_id'] for eb in answer['edge_bindings']]
        these_edge_ids = flatten_semilist(these_edge_ids)
        edge_ids.update(these_edge_ids)

    if 'edges' in kgraph:
        kgraph['edges'] = [e for e in kgraph['edges'] if e['id'] in edge_ids]
    kgraph['nodes'] = [n for n in kgraph['nodes'] if n['id'] in node_ids]

    message['knowledge_graph'] = kgraph
    message['results'] = answers
    return Message(**message)
async def query(request: Request) -> Message:
    """Normalize."""
    message = request.message.dict()

    qgraph = message['query_graph']

    qcuries = {
        curie
        for node in qgraph['nodes']
        if 'curie' in node
        for curie in ensure_list(node['curie'])
    }
    try:
        knode_ids = {
            node['id']
            for node in message['knowledge_graph']['nodes']
        }
    except (KeyError, TypeError):
        # knowledge graph is absent or malformed
        knode_ids = set()
    curies = {curie for curie in qcuries | knode_ids if curie}
    curie_map = dict(zip(curies, synonymize(*curies)))

    for node in qgraph['nodes']:
        if not node.get('curie', None):
            continue
        node['curie'] = [curie_map[ci] for ci in ensure_list(node['curie'])]

    if not message.get('knowledge_graph', None) or ('nodes' not in message['knowledge_graph']):
        return Message(**message)
    for node in message['knowledge_graph']['nodes']:
        node['id'] = curie_map[node['id']]
    for edge in message['knowledge_graph']['edges']:
        edge['source_id'] = curie_map[edge['source_id']]
        edge['target_id'] = curie_map[edge['target_id']]

    if 'results' not in message:
        return Message(**message)
    for result in message['results']:
        for binding in result['node_bindings']:
            binding['kg_id'] = curie_map[binding['kg_id']]

    return Message(**message)
async def query(request: Request) -> Message:
    """Minify message.

    for knowledge graph:
      * keep only node properties: id, name, type
      * keep only edge properties: id, source_id, target_id, type
    for results:
      * keep only qg_id, kg_id
    """
    message = request.message.dict()
    kgraph = message['knowledge_graph']
    results = message['results']

    kgraph['nodes'] = [{
        'id': node['id'],
        'name': node['name'],
        'type': node['type'],
    } for node in kgraph['nodes']]
    kgraph['edges'] = [{
        'id': edge['id'],
        'source_id': edge['source_id'],
        'target_id': edge['target_id'],
        'type': edge['type'],
    } for edge in kgraph['edges']]
    for result in results:
        result['node_bindings'] = [{
            'qg_id': nb['qg_id'],
            'kg_id': nb['kg_id'],
        } for nb in result['node_bindings']]
        result['edge_bindings'] = [{
            'qg_id': eb['qg_id'],
            'kg_id': eb['kg_id'],
        } for eb in result['edge_bindings']]

    message['knowledge_graph'] = kgraph
    message['results'] = results
    return Message(**message)
async def run_workflow(
    request: ReasonerQuery = Body(..., example=load_example("query")),
) -> Response:
    """Run workflow."""
    request_dict = request.dict(exclude_unset=True)

    message = request_dict["message"]
    workflow = request_dict["workflow"]
    logger = gen_logger()
    qgraph = message["query_graph"]
    kgraph = {"nodes": {}, "edges": {}}
    if "knowledge_graph" in message.keys():
        if "nodes" in message["knowledge_graph"].keys():
            kgraph = message["knowledge_graph"]

    async with httpx.AsyncClient(verify=False, timeout=60.0) as client:
        for operation in workflow:
            service_operation_responses = []
            for service in SERVICES[operation["id"]]:
                url = service["url"]
                service_name = service["title"]
                logger.debug(
                    f"Requesting operation '{operation}' from {service_name}..."
                )
                try:
                    response = await post_safely(
                        url,
                        {
                            "message": message,
                            "workflow": [
                                operation,
                            ],
                            "submitter": "Workflow Runner",
                        },
                        client=client,
                        timeout=60.0,
                        logger=logger,
                        service_name=service_name,
                    )
                    logger.debug(
                        f"Received operation '{operation}' from {service_name}..."
                    )
                    try:
                        response = await post_safely(
                            NORMALIZER_URL + "/response",
                            {
                                "message": response["message"],
                                "submitter": "Workflow Runner",
                            },
                            client=client,
                            timeout=60.0,
                            logger=logger,
                            service_name="node_normalizer",
                        )
                    except RuntimeError as e:
                        logger.warning({"error": str(e)})

                    service_operation_responses.append(response)

                except RuntimeError as e:
                    logger.warning({"error": str(e)})

                if (
                    not OPERATIONS[operation["id"]]["unique"]
                    and len(service_operation_responses) == 1
                ):
                    # We only need one successful response for non-unique operations
                    break

            logger.debug(
                f"Merging {len(service_operation_responses)} responses for '{operation}'..."
            )
            m = Message(
                query_graph=QueryGraph.parse_obj(qgraph),
                knowledge_graph=KnowledgeGraph.parse_obj(kgraph),
            )
            for response in service_operation_responses:
                response["message"]["query_graph"] = qgraph
                m.update(Message.parse_obj(response["message"]))
            message = m.dict()

    return Response(
        message=message,
        workflow=workflow,
        logs=logger.handlers[0].store,
    )
async def setup(
    self,
    qgraph: dict,
):
    """Set up."""
    # Update qgraph identifiers
    message = Message.parse_obj({"query_graph": qgraph})
    curies = get_curies(message)
    if len(curies):
        await self.synonymizer.load_curies(*curies)
        curie_map = self.synonymizer.map(curies, self.preferred_prefixes)
        map_qgraph_curies(message.query_graph, curie_map, primary=True)

    self.qgraph = message.query_graph.dict()

    # Fill in missing categories and predicates using normalizer
    await fill_categories_predicates(self.qgraph, self.logger)

    # Initialize registry
    registry = Registry(settings.kpregistry_url, self.logger)

    # Generate traversal plan
    self.plan, kps = await generate_plan(
        self.qgraph,
        kp_registry=registry,
        logger=self.logger,
    )

    # extract KP preferred prefixes from plan
    self.kp_preferred_prefixes = dict()
    self.portal.tservers = dict()
    for kp_id, kp in kps.items():
        # swap the operation path (assumes the registered KP url ends in "query")
        url = kp["url"][:-5] + "meta_knowledge_graph"
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    url,
                    timeout=10,
                )
                response.raise_for_status()
                meta_kg = response.json()
            MetaKnowledgeGraph.parse_obj(meta_kg)
            self.kp_preferred_prefixes[kp_id] = {
                category: data["id_prefixes"]
                for category, data in meta_kg["nodes"].items()
            }
        except (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout) as err:
            self.logger.warning(
                "Unable to get meta knowledge graph from KP {}: {}".format(
                    kp_id,
                    str(err),
                ),
            )
            self.kp_preferred_prefixes[kp_id] = dict()
        except httpx.HTTPStatusError as e:
            self.logger.warning(
                "Received error response from /meta_knowledge_graph for KP {}: {}".format(
                    kp_id,
                    e.response.text,
                ),
            )
            self.kp_preferred_prefixes[kp_id] = dict()
        except JSONDecodeError as err:
            self.logger.warning(
                "Unable to parse meta knowledge graph from KP {}: {}".format(
                    kp_id,
                    str(err),
                ),
            )
            self.kp_preferred_prefixes[kp_id] = dict()
        except pydantic.ValidationError as err:
            self.logger.warning(
                "Meta knowledge graph from KP {} is non-compliant: {}".format(
                    kp_id,
                    str(err),
                ),
            )
            self.kp_preferred_prefixes[kp_id] = dict()
        except Exception as err:
            self.logger.warning(
                "Something went wrong while parsing meta knowledge graph from KP {}: {}".format(
                    kp_id,
                    str(err),
                ),
            )
            self.kp_preferred_prefixes[kp_id] = dict()

        self.portal.tservers[kp_id] = ThrottledServer(
            kp_id,
            url=kp["url"],
            request_qty=1,
            request_duration=1,
            preproc=self.get_processor(self.kp_preferred_prefixes[kp_id]),
            postproc=self.get_processor(self.preferred_prefixes),
            logger=self.logger,
        )

    self.kps = {
        kp_id: KnowledgeProvider(
            details,
            self.portal,
            kp_id,
        )
        for kp_id, details in kps.items()
    }
async def query(request: Request) -> Message:
    """Add support to message.

    Add support edges to knowledge_graph and bindings to results.
    """
    message = request.message.dict()

    kgraph = message['knowledge_graph']
    qgraph = message['query_graph']
    answers = message['results']

    # get cache if possible
    try:
        cache = Cache(
            redis_host=CACHE_HOST,
            redis_port=CACHE_PORT,
            redis_db=CACHE_DB,
            redis_password=CACHE_PASSWORD,
        )
    except Exception as err:
        logger.exception(err)
        cache = None

    redis_batch_size = 100

    async with OmnicorpSupport() as supporter:
        # get all node supports
        keys = [
            f"{supporter.__class__.__name__}({node['id']})"
            for node in kgraph['nodes']
        ]
        values = []
        for batch in batches(keys, redis_batch_size):
            # fall back to cache misses when no cache is available
            values.extend(cache.mget(*batch) if cache else [None] * len(batch))
        jobs = [
            count_node_pmids(supporter, node, key, value, cache)
            for node, value, key in zip(kgraph['nodes'], values, keys)
        ]

        # Generate a set of pairs of node curies
        pair_to_answer = defaultdict(set)  # a map of node pairs to answers
        for ans_idx, answer_map in enumerate(answers):

            # Get all nodes that are not part of sets and densely connect them
            nodes = sorted([
                nb['kg_id']
                for nb in answer_map['node_bindings']
                if isinstance(nb['kg_id'], str)
            ])
            for node_pair in combinations(nodes, 2):
                pair_to_answer[node_pair].add(ans_idx)

            # For all nodes that are within sets, connect them to all nodes that are not in sets
            set_nodes_list_list = [
                nb['kg_id']
                for nb in answer_map['node_bindings']
                if isinstance(nb['kg_id'], list)
            ]
            set_nodes = [n for el in set_nodes_list_list for n in el]
            for set_node in set_nodes:
                for node in nodes:
                    node_pair = tuple(sorted((node, set_node)))
                    pair_to_answer[node_pair].add(ans_idx)

        # get all pair supports
        cached_prefixes = cache.get('OmnicorpPrefixes') if cache else None
        keys = [
            f"{supporter.__class__.__name__}_count({pair[0]},{pair[1]})"
            for pair in pair_to_answer
        ]
        values = []
        for batch in batches(keys, redis_batch_size):
            values.extend(cache.mget(*batch) if cache else [None] * len(batch))
        jobs.extend([
            count_shared_pmids(
                supporter,
                support_idx,
                pair,
                key,
                value,
                cache,
                cached_prefixes,
                kgraph,
                pair_to_answer,
                answers,
            )
            for support_idx, (pair, value, key) in enumerate(
                zip(pair_to_answer, values, keys)
            )
        ])
        await asyncio.gather(*jobs)

    message['knowledge_graph'] = kgraph
    message['results'] = answers
    return Message(**message)
async def query(request: Request, *, exclude_sets=False) -> Message:
    """Compute informativeness weights for edges."""
    message = request.message.dict()

    qgraph = message['query_graph']
    results = message['results']

    qnodes = qgraph['nodes']
    qedges = qgraph['edges']

    # knode_map = {knode['id']: knode for knode in knodes}
    qnode_map = {qnode['id']: qnode for qnode in qnodes}
    qedge_map = {qedge['id']: qedge for qedge in qedges}

    driver = Neo4jDatabase(
        url=NEO4J_URL,
        credentials={
            'username': NEO4J_USER,
            'password': NEO4J_PASSWORD,
        },
    )

    redges_by_id = dict()
    count_plans = defaultdict(lambda: defaultdict(list))
    for kdx, result in enumerate(results):
        rgraph = get_rgraph(result, message)
        redges_by_id.update({
            (kdx, redge['id']): redge
            for redge in rgraph['edges']
        })

        for redge in rgraph['edges']:
            if (not exclude_sets) or qnode_map[redge['qg_target_id']].get('set', False):
                count_plans[redge['kg_source_id']][
                    (redge['eb']['qg_id'], redge['qg_target_id'])
                ].append((kdx, redge['id']))
            if (not exclude_sets) or qnode_map[redge['qg_source_id']].get('set', False):
                count_plans[redge['kg_target_id']][
                    (redge['eb']['qg_id'], redge['qg_source_id'])
                ].append((kdx, redge['id']))

    count_to_redge = {}
    for ldx, batch in enumerate(batches(list(count_plans.keys()), 1000)):
        batch_bits = []
        for idx, ksource_id in enumerate(batch):
            sets = []
            plan = count_plans[ksource_id]
            anchor_node_reference = NodeReference({
                'id': f'n{idx:04d}',
                'curie': ksource_id,
                'type': 'named_thing',
            })
            anchor_node_reference = str(anchor_node_reference)
            base = f"MATCH ({anchor_node_reference}) "
            for jdx, (qlink, redge_ids) in enumerate(plan.items()):
                cypher_counts = []
                qedge_id, qtarget_id = qlink
                count_id = f"c{idx:03d}{chr(97 + jdx)}"
                qedge = qedge_map[qedge_id]
                edge_reference = EdgeReference(qedge, anonymous=True)
                anon_node_reference = NodeReference({
                    **qnode_map[qtarget_id],
                    'id': count_id,
                })
                if qedge['source_id'] == qtarget_id:
                    source_reference = anon_node_reference
                    target_reference = anchor_node_reference
                elif qedge['target_id'] == qtarget_id:
                    source_reference = anchor_node_reference
                    target_reference = anon_node_reference
                cypher_counts.append(
                    f"{anon_node_reference.name}: count(DISTINCT {anon_node_reference.name})"
                )
                count_to_redge[count_id] = redge_ids
                sets.append(
                    f'MATCH ({source_reference}){edge_reference}({target_reference})'
                    + ' RETURN {' + ', '.join(cypher_counts) + '} as output'
                )
            batch_bits.append(' UNION ALL '.join(sets))
        cypher = ' UNION ALL '.join(batch_bits)
        response = driver.run(cypher)

        degrees = {
            key: value
            for result in response
            for key, value in result['output'].items()
        }

        for key in degrees:
            for redge_id in count_to_redge[key]:
                eb = redges_by_id[redge_id]['eb']
                eb['weight'] = eb.get('weight', 1.0) / degrees[key]

    message['results'] = results
    return Message(**message)
async def process_batch(
    self,
):
    """Set up a subscriber to process batching."""
    # Initialize the TAT
    #
    # TAT = Theoretical Arrival Time
    # When the next request should be sent
    # to adhere to the rate limit.
    #
    # This is an implementation of the GCRA algorithm
    # More information can be found here:
    # https://dev.to/astagi/rate-limiting-using-python-and-redis-58gk
    # (a standalone sketch of this TAT bookkeeping follows the function below)
    if self.request_qty > 0:
        interval = self.request_duration / self.request_qty
        tat = datetime.datetime.utcnow() + interval

    while True:
        # Get everything in the stream or wait for something to show up
        priority, (
            request_id,
            payload,
            response_queue,
        ) = await self.request_queue.get()
        priorities = {request_id: priority}
        request_value_mapping = {request_id: payload}
        response_queues = {request_id: response_queue}
        while True:
            if (
                self.max_batch_size is not None
                and len(request_value_mapping) == self.max_batch_size
            ):
                break
            try:
                priority, (
                    request_id,
                    payload,
                    response_queue,
                ) = self.request_queue.get_nowait()
            except QueueEmpty:
                break
            priorities[request_id] = priority
            request_value_mapping[request_id] = payload
            response_queues[request_id] = response_queue

        # Extract a curie mapping from each request
        request_curie_mapping = {
            request_id: get_curies(request_value.message.query_graph)
            for request_id, request_value in request_value_mapping.items()
        }

        # Find requests that are the same (those that we can merge)
        # This disregards non-matching IDs because the IDs have been
        # removed with the extract_curie method
        stripped_qgraphs = {
            request_id: remove_curies(request.message.query_graph)
            for request_id, request in request_value_mapping.items()
        }
        first_value = next(iter(stripped_qgraphs.values()))

        batch_request_ids = get_keys_with_value(
            stripped_qgraphs,
            first_value,
        )

        # Re-queue the un-selected requests
        for request_id in request_value_mapping:
            if request_id not in batch_request_ids:
                await self.request_queue.put(
                    (
                        priorities[request_id],
                        (
                            request_id,
                            request_value_mapping[request_id],
                            response_queues[request_id],
                        ),
                    )
                )

        request_value_mapping = {
            k: v for k, v in request_value_mapping.items() if k in batch_request_ids
        }

        # Filter curie mapping to only include matching requests
        request_curie_mapping = {
            k: v for k, v in request_curie_mapping.items() if k in batch_request_ids
        }

        # Pull first value from request_value_mapping
        # to use as a template for our merged request
        merged_request_value = copy.deepcopy(
            next(iter(request_value_mapping.values()))
        )

        # Remove qnode ids
        for qnode in merged_request_value.message.query_graph.nodes.values():
            qnode.ids = None

        # Update merged request using curie mapping
        for curie_mapping in request_curie_mapping.values():
            for node_id, node_curies in curie_mapping.items():
                node = merged_request_value.message.query_graph.nodes[node_id]
                if node.ids is None:
                    node.ids = node_curies.copy()
                else:
                    node.ids.extend(node_curies)

        # TODO replace qnode.ids with a HashableSet so that this can be safely removed
        for qnode in merged_request_value.message.query_graph.nodes.values():
            if qnode.ids:
                qnode.ids = list(set(qnode.ids))

        response_values = dict()
        try:
            # Make request
            self.logger.info(
                "[{id}] Sending request made of {subrequests} subrequests ({curies} curies)".format(
                    id=self.id,
                    subrequests=len(request_curie_mapping),
                    curies=" x ".join(
                        str(len(qnode.ids or []))
                        for qnode in merged_request_value.message.query_graph.nodes.values()
                    ),
                )
            )
            self.logger.context = self.id
            await self.preproc(merged_request_value, self.logger)
            # TODO rewrite this whole function to use pydantic model
            merged_request_value = merged_request_value.dict()
            merged_request_value["submitter"] = "infores:aragorn"
            merged_request_value = remove_null_values(merged_request_value)
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.url,
                    json=merged_request_value,
                    timeout=self.timeout,
                )
            if response.status_code == 429:
                # reset TAT
                interval = self.request_duration / self.request_qty
                tat = datetime.datetime.utcnow() + interval
                # re-queue requests
                for request_id in request_value_mapping:
                    await self.request_queue.put(
                        (
                            priorities[request_id],
                            (
                                request_id,
                                request_value_mapping[request_id],
                                response_queues[request_id],
                            ),
                        )
                    )
                # try again later
                continue
            response.raise_for_status()

            # Parse with reasoner_pydantic to validate
            response_body = ReasonerResponse.parse_obj(response.json())
            await self.postproc(response_body, self.logger)
            message = response_body.message

            results = message.results or []
            self.logger.info(
                "[{}] Received response with {} results in {} seconds".format(
                    self.id,
                    len(results),
                    response.elapsed.total_seconds(),
                )
            )

            try:
                if len(request_curie_mapping) == 1:
                    request_id = next(iter(request_curie_mapping))
                    # Make a copy
                    response_values[request_id] = ReasonerResponse(
                        message=Message()
                    )
                    response_values[request_id].message.query_graph = (
                        request_value_mapping[request_id].message.query_graph.copy()
                    )
                    response_values[request_id].message.knowledge_graph = (
                        message.knowledge_graph
                        or KnowledgeGraph(nodes={}, edges={})
                    ).copy()
                    response_values[request_id].message.results = (
                        message.results or HashableSet(__root__=[])
                    ).copy()
                else:
                    # Split using the request_curie_mapping
                    for request_id, curie_mapping in request_curie_mapping.items():
                        filtered_msg = filter_by_curie_mapping(
                            message, curie_mapping, kp_id=self.id
                        )
                        filtered_msg.query_graph = (
                            request_value_mapping[request_id].message.query_graph.copy()
                        )
                        response_values[request_id] = ReasonerResponse(
                            message=filtered_msg
                        )
            except Exception as err:
                # Raise more descriptive error message of response message parsing
                raise Exception(
                    "[{}] Failed to parse message response: {} with Error: {}".format(
                        self.id,
                        response.json(),
                        err,
                    )
                )
        except (
            asyncio.exceptions.TimeoutError,
            httpx.RequestError,
            httpx.HTTPStatusError,
            JSONDecodeError,
            pydantic.ValidationError,
            Exception,
        ) as e:
            for request_id, curie_mapping in request_curie_mapping.items():
                response_values[request_id] = ReasonerResponse(
                    message=request_value_mapping[request_id].message.copy()
                )
            if isinstance(e, asyncio.TimeoutError):
                self.logger.warning(
                    {
                        "message": f"{self.id} took >60 seconds to respond",
                        "error": str(e),
                        "request": elide_curies(merged_request_value),
                    }
                )
            elif isinstance(e, httpx.ReadTimeout):
                self.logger.warning(
                    {
                        "message": f"{self.id} took >60 seconds to respond",
                        "error": str(e),
                        "request": log_request(e.request),
                    }
                )
            elif isinstance(e, httpx.RequestError):
                # Log error
                self.logger.warning(
                    {
                        "message": f"Request Error contacting {self.id}",
                        "error": str(e),
                        "request": log_request(e.request),
                    }
                )
            elif isinstance(e, httpx.HTTPStatusError):
                # Log error with response
                self.logger.warning(
                    {
                        "message": f"Response Error contacting {self.id}",
                        "error": str(e),
                        "request": log_request(e.request),
                        "response": log_response(e.response),
                    }
                )
            elif isinstance(e, JSONDecodeError):
                # Log error with response
                self.logger.warning(
                    {
                        "message": f"Received bad JSON data from {self.id}",
                        "request": e.request,
                        "response": e.response.text,
                        "error": str(e),
                    }
                )
            elif isinstance(e, pydantic.ValidationError):
                self.logger.warning(
                    {
                        "message": f"Received non-TRAPI compliant response from {self.id}",
                        "error": str(e),
                    }
                )
            else:
                self.logger.warning(
                    {
                        "message": f"Something went wrong while querying {self.id}",
                        "error": str(e),
                    }
                )

        for request_id, response_value in response_values.items():
            # Write finished value to DB
            await response_queues[request_id].put(response_value)

        # if request_qty == 0 we don't enforce the rate limit
        if self.request_qty > 0:
            time_remaining_seconds = (
                tat - datetime.datetime.utcnow()
            ).total_seconds()

            # Wait for TAT
            if time_remaining_seconds > 0:
                await asyncio.sleep(time_remaining_seconds)

            # Update TAT
            tat = datetime.datetime.utcnow() + interval
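A standalone sketch of the GCRA-style TAT bookkeeping used in process_batch above, stripped of the batching, merging, and KP-specific logic. The queue and handle callable are placeholders (not part of the original code); the rate-limit arithmetic mirrors what the function does with request_qty and request_duration.

import asyncio
import datetime


async def rate_limited_consumer(
    queue: asyncio.Queue,
    handle,  # hypothetical coroutine that performs one request
    request_qty: int = 1,
    request_duration: datetime.timedelta = datetime.timedelta(seconds=1),
):
    """Drain `queue`, allowing at most request_qty requests per request_duration."""
    interval = request_duration / request_qty
    # Theoretical Arrival Time: when the next request may be sent
    tat = datetime.datetime.utcnow() + interval
    while True:
        item = await queue.get()
        await handle(item)
        # Sleep until the TAT before taking the next item
        remaining = (tat - datetime.datetime.utcnow()).total_seconds()
        if remaining > 0:
            await asyncio.sleep(remaining)
        # Advance the TAT by one interval
        tat = datetime.datetime.utcnow() + interval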
async def query(
    request: Request,
    relevance: Optional[float] = Query(
        0.0025,
        description='portion of cooccurrence pubs relevant to question',
    ),
    wt_min: Optional[float] = Query(
        0.0,
        description='minimum weight (at 0 pubs)',
    ),
    wt_max: Optional[float] = Query(
        1.0,
        description='maximum weight (at inf pubs)',
    ),
    p50: Optional[float] = Query(
        2.0,
        description='pubs at 50% of wt_max',
    ),
) -> Message:
    """Weight kgraph edges based on metadata.

    "19 pubs from CTD is a 1, and 2 should at least be 0.5"
        - cbizon
    """
    message = request.message.dict()

    def sigmoid(x):
        """Scale with partial sigmoid - the right (concave down) half.

        Such that:
            f(0) = wt_min
            f(inf) = wt_max
            f(p50) = 0.5 * wt_max
        """
        a = 2 * (wt_max - wt_min)
        r = 0.5 * wt_max
        c = wt_max - 2 * wt_min
        k = 1 / p50 * (math.log(r + c) - math.log(a - r - c))
        return a / (1 + math.exp(-k * x)) - c

    kgraph = message['knowledge_graph']
    node_pubs = {
        n['id']: n.get('omnicorp_article_count', None)
        for n in kgraph['nodes']
    }
    all_pubs = 27840000

    results = message['results']

    # ensure that each edge_binding has a single kg_id
    for result in results:
        result['edge_bindings'] = [
            eb
            for ebs in result['edge_bindings']
            for eb in (
                [
                    {
                        'qg_id': ebs['qg_id'],
                        'kg_id': kg_id,
                    }
                    for kg_id in ebs['kg_id']
                ]
                if isinstance(ebs['kg_id'], list)
                else [ebs]
            )
        ]

    # map kedges to edge_bindings
    krmap = defaultdict(list)
    for result in results:
        for eb in result['edge_bindings']:
            assert isinstance(eb['kg_id'], str)
            eb['weight'] = eb.get('weight', 1.0)
            krmap[eb['kg_id']].append(eb)

    edges = kgraph['edges']
    for edge in edges:
        edge_pubs = edge.get('num_publications', len(edge.get('publications', [])))
        if edge['type'] == 'literature_co-occurrence':
            source_pubs = int(node_pubs[edge['source_id']])
            target_pubs = int(node_pubs[edge['target_id']])

            cov = (edge_pubs / all_pubs) - (source_pubs / all_pubs) * (target_pubs / all_pubs)
            cov = max((cov, 0.0))
            effective_pubs = cov * all_pubs * relevance
        else:
            effective_pubs = edge_pubs + 1  # consider the curation a pub

        for redge in krmap[edge['id']]:
            redge['weight'] = redge.get('weight', 1.0) * sigmoid(effective_pubs)

    message['knowledge_graph'] = kgraph
    return Message(**message)
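A quick numeric check of the scaled-sigmoid constants above, using the endpoint's default parameters (wt_min=0.0, wt_max=1.0, p50=2.0). The standalone function below just mirrors the inner sigmoid for illustration; it is not part of the service code.

import math


def scaled_sigmoid(x, wt_min=0.0, wt_max=1.0, p50=2.0):
    """Right half of a sigmoid scaled so f(0)=wt_min, f(inf)=wt_max, f(p50)=wt_max/2."""
    a = 2 * (wt_max - wt_min)
    r = 0.5 * wt_max
    c = wt_max - 2 * wt_min
    k = 1 / p50 * (math.log(r + c) - math.log(a - r - c))
    return a / (1 + math.exp(-k * x)) - c


assert abs(scaled_sigmoid(0.0) - 0.0) < 1e-9    # f(0) = wt_min
assert abs(scaled_sigmoid(2.0) - 0.5) < 1e-9    # f(p50) = 0.5 * wt_max
assert abs(scaled_sigmoid(1e6) - 1.0) < 1e-6    # f(x) -> wt_max as x -> inf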
]

table = "\n\n"
table += " Benchmark Name | Output Size (MB) | Total Time (s)\n"
table += "--------------------------------------------------------------------------\n"

for b in benchmarks:
    input_messages = [
        generate_message_parameterized(**b["params"])
        for _ in range(b["msg_count"])
    ]

    start = time.time()
    combined_msg = Message(results=[])
    print(f"Running benchmark {b['name']}")
    for m in tqdm(input_messages):
        combined_msg.update(m)
    end = time.time()

    # Compute file size
    print("\nComputing final message size, this may take a while...")
    output_file_size = len(combined_msg.json())

    table += f" {b['name'].center(32)} | {output_file_size/1e6:16} | {end - start:14.2f}\n"

print(table)