def _pa_batch_iter(self, db, in_mks=None, ex_mks=None):
    """Return an iterator over batches of preassembled statements.

    This avoids the need to load all such statements from the database into
    RAM at the same time (as this can be quite large).

    You may limit the set of pa_statements loaded by providing a set/list of
    matches-keys of the statements you wish to include (`in_mks`) and/or
    exclude (`ex_mks`). A filter that is None or empty is not applied; each
    non-empty filter is applied regardless of the other one.

    Parameters
    ----------
    db : DatabaseManager
        The database to read PA statement JSON from.
    in_mks : Optional[collection]
        If non-empty, only statements whose mk_hash is in it are loaded.
    ex_mks : Optional[collection]
        If non-empty, statements whose mk_hash is in it are skipped.

    Returns
    -------
    iterator
        Batches (lists) of Statements, grouped by ``self.batch_size``.
    """
    # Build each filter independently so that a non-empty include/exclude
    # set is honored even when the other argument is an empty collection.
    clauses = []
    if ex_mks:
        clauses.append(db.PAStatements.mk_hash.notin_(ex_mks))
    if in_mks:
        clauses.append(db.PAStatements.mk_hash.in_(in_mks))

    db_stmt_iter = db.select_all(db.PAStatements.json, *clauses,
                                 yield_per=self.batch_size)
    pa_stmts = (_stmt_from_json(s_json) for s_json, in db_stmt_iter)
    return batch_iter(pa_stmts, self.batch_size, return_func=list)
def _raw_sid_stmt_iter(self, db, id_set, do_enumerate=False):
    """Yield batches of (raw statement id, Statement) pairs for `id_set`.

    Ids are processed in groups of ``self.batch_size``. Each raw statement
    is outer-joined through Reading and TextContent to its TextRef; when a
    TextRef is found, the statement's first evidence gets its ``pmid`` and a
    ``text_refs`` dict of the TextRef's public (non-underscore) attributes.

    If `do_enumerate` is True, each yielded item is ``(batch_index, pairs)``
    instead of just ``pairs``.
    """
    def _attach_text_ref(stmt, text_ref):
        if text_ref is None:
            return stmt
        first_ev = stmt.evidence[0]
        first_ev.pmid = text_ref.pmid
        first_ev.text_refs = {attr: val
                              for attr, val in text_ref.__dict__.items()
                              if not attr.startswith('_')}
        return stmt

    id_batches = batch_iter(id_set, self.batch_size)
    for batch_idx, id_batch in enumerate(id_batches):
        query = db.filter_query(
            [db.RawStatements.id, db.RawStatements.json, db.TextRef],
            db.RawStatements.id.in_(id_batch))
        query = query.outerjoin(db.Reading)
        query = query.outerjoin(db.TextContent)
        query = query.outerjoin(db.TextRef)
        rows = query.yield_per(self.batch_size // 10)

        pairs = [(sid, _attach_text_ref(_stmt_from_json(raw_json), ref))
                 for sid, raw_json, ref in rows]
        yield (batch_idx, pairs) if do_enumerate else pairs
def _get_unique_statements(self, db, stmt_tpls, num_stmts, mk_done=None):
    """Get the unique Statements from the raw statements.

    Raw statement tuples are processed in batches of ``self.batch_size``.
    For every batch, unique statements are computed; those whose shallow
    hash is neither in `mk_done` nor already added by this call are inserted
    as PA statements, and the batch's raw-unique evidence links are copied
    into ``raw_unique_links``.

    Parameters
    ----------
    db : DatabaseManager
        The database to write to.
    stmt_tpls : iterable
        Raw statement tuples, consumed in batches.
    num_stmts : int
        Total number of statement tuples (used for progress messages).
    mk_done : Optional[set]
        Shallow hashes that already exist and must not be inserted again.

    Returns
    -------
    set
        Shallow hashes of the PA statements newly inserted by this call.
    """
    if mk_done is None:
        mk_done = set()

    new_mk_set = set()
    stmt_batches = batch_iter(stmt_tpls, self.batch_size, return_func=list)
    # Ceiling division so the progress message reports the real batch count.
    num_batches = (num_stmts + self.batch_size - 1) // self.batch_size
    for i, stmt_tpl_batch in enumerate(stmt_batches):
        self._log("Processing batch %d/%d of %d/%d statements."
                  % (i + 1, num_batches, len(stmt_tpl_batch), num_stmts))
        unique_stmts, evidence_links = \
            self._make_unique_statement_set(stmt_tpl_batch)
        new_unique_stmts = []
        for s in unique_stmts:
            s_hash = s.get_hash(shallow=True)
            # Check each set separately: building `mk_done | new_mk_set`
            # per statement copies both sets and makes this loop quadratic.
            if s_hash not in mk_done and s_hash not in new_mk_set:
                new_mk_set.add(s_hash)
                new_unique_stmts.append(s)
        insert_pa_stmts(db, new_unique_stmts)
        db.copy('raw_unique_links', evidence_links,
                ('pa_stmt_mk_hash', 'raw_stmt_id'))
    self._log("Added %d new pa statements into the database."
              % len(new_mk_set))
    return new_mk_set
def get_unique_text_refs():
    """Get unique INDRA DB TextRefs for all identifiers in CORD19.

    Queries TextRef IDs matching the PMCIDs, PMIDs and DOIs listed in the
    CORD19 metadata (DOIs in batches of 10000), deduplicates the IDs, and
    then loads the corresponding TextRef rows.

    Returns
    -------
    list of TextRef
        The TextRef database objects for the unique IDs found, as returned
        by ``db.select_all``.
    """
    pmcids = get_ids('pmcid')
    pmids = [fix_pmid(pmid) for pmid in get_ids('pubmed_id')]
    dois = [fix_doi(doi) for doi in get_ids('doi')]

    # Get unique text_refs from the DB
    db = get_primary_db()
    print("Getting TextRefs by PMCID")
    tr_pmcids = db.select_all(db.TextRef.id, db.TextRef.pmcid_in(pmcids))
    print("Getting TextRefs by PMID")
    tr_pmids = db.select_all(db.TextRef.id, db.TextRef.pmid_in(pmids))
    tr_dois = []
    for ix, doi_batch in enumerate(batch_iter(dois, 10000)):
        print("Getting Text Refs by DOI batch", ix)
        tr_doi_batch = db.select_all(
            db.TextRef.id, db.TextRef.doi_in(doi_batch, filter_ids=True))
        tr_dois.extend(tr_doi_batch)

    # A TextRef may match on several identifier types; dedupe by ID.
    ids = {res.id
           for res_list in (tr_dois, tr_pmcids, tr_pmids)
           for res in res_list}
    print(len(ids), "unique TextRefs in DB")
    trs = db.select_all(db.TextRef, db.TextRef.id.in_(ids))
    return trs
def get_indradb_pa_stmts(start_batch=6, batch_size=1000):
    """Get preassembled INDRA Stmts for PMC articles from INDRA DB.

    DEPRECATED. Get Raw Statements instead.

    Each processed batch is also dumped to ``batch_%02d.pkl`` and the full
    result to ``cord19_pmc_stmts.pkl``.

    Parameters
    ----------
    start_batch : int
        Index of the first batch of papers to query; earlier batches are
        skipped. The default of 6 keeps this function's historical behavior
        (it skipped batches 0-5); pass 0 to query every paper.
    batch_size : int
        Number of papers per query. Default: 1000.

    Returns
    -------
    list of dict
        The statement JSONs retrieved from all queried batches.
    """
    # Get the list of all PMCIDs from the corpus metadata
    pmcids = get_ids('pmcid')
    paper_refs = [('pmcid', p) for p in pmcids]
    stmt_jsons = []
    start = time.time()
    for batch_ix, paper_batch in enumerate(batch_iter(paper_refs,
                                                      batch_size)):
        if batch_ix < start_batch:
            continue
        papers = list(paper_batch)
        # Report the real batch length; the final batch may be short.
        print("Querying DB for statements for %d papers" % len(papers))
        batch_start = time.time()
        result = get_statement_jsons_from_papers(papers)
        batch_elapsed = time.time() - batch_start
        batch_jsons = list(result['statements'].values())
        print("Returned %d stmts in %f sec" %
              (len(batch_jsons), batch_elapsed))
        batch_stmts = stmts_from_json(batch_jsons)
        ac.dump_statements(batch_stmts, 'batch_%02d.pkl' % batch_ix)
        stmt_jsons += batch_jsons
    elapsed = time.time() - start
    print("Total time: %f sec, %d papers" % (elapsed, len(paper_refs)))
    stmts = stmts_from_json(stmt_jsons)
    ac.dump_statements(stmts, 'cord19_pmc_stmts.pkl')
    return stmt_jsons
def populate(self, db):
    """Add new TextRefs and TextContent from this source to the database.

    TextRef records are built from ``self.tr_data`` (columns in
    ``self.tr_cols``) and checked against the DB with ``filter_text_refs``
    in batches of 10000; records that are not present yet are copied into
    the ``text_ref`` table. Text content whose (pmid, pmcid, doi) matches a
    flawed TextRef record is dropped, and the rest is filtered with
    ``filter_text_content`` and copied into the ``text_content`` table.

    Parameters
    ----------
    db : DatabaseManager
        The database to update.

    Returns
    -------
    dict
        Intermediate results keyed by 'filtered_tr_records',
        'flawed_tr_records', 'mod_tc_data' and 'filtered_tc_records'.
    """
    # Turn the list of dicts into a set of tuples
    tr_data_set = {tuple(entry[id_type] for id_type in self.tr_cols)
                   for entry in self.tr_data}

    # Filter_text_refs will figure out which articles are already in the
    # TextRef table and will update them with any new metadata;
    # filtered_tr_records are the ones that need to be added to the DB
    filtered_tr_records = set()
    flawed_tr_records = set()
    for ix, tr_batch in enumerate(batch_iter(tr_data_set, 10000)):
        print("Getting Text Refs using pmid/pmcid/doi, batch", ix)
        filt_batch, flaw_batch = \
            self.filter_text_refs(db, set(tr_batch),
                                  primary_id_types=['pmid', 'pmcid', 'doi'])
        filtered_tr_records |= set(filt_batch)
        flawed_tr_records |= set(flaw_batch)

    # flawed_tr_records holds (cause, record) pairs; content is skipped for
    # every flawed record, whatever the cause.
    # NOTE(review): the membership test below compares records against
    # (pmid, pmcid, doi) tuples, which assumes self.tr_cols is exactly
    # ('pmid', 'pmcid', 'doi') -- confirm.
    trs_to_skip = {rec for cause, rec in flawed_tr_records}

    # Then we put together the updated text content data
    if trs_to_skip:
        mod_tc_data = [
            tc for tc in self.tc_data
            if (tc.get('pmid'), tc.get('pmcid'), tc.get('doi'))
            not in trs_to_skip
        ]
    else:
        mod_tc_data = self.tc_data

    # Upload TextRef data for articles NOT already in the DB
    logger.info('Adding %d new text refs...' % len(filtered_tr_records))
    if filtered_tr_records:
        self.copy_into_db(db, 'text_ref', filtered_tr_records, self.tr_cols)
        gatherer.add('refs', len(filtered_tr_records))

    # Process the text content data. filter_text_content returns a bare
    # list (not a pair) for empty input, so don't try to unpack that case.
    if mod_tc_data:
        filtered_tc_records, _ = self.filter_text_content(db, mod_tc_data)
    else:
        filtered_tc_records = []

    # Upload the text content data.
    logger.info('Adding %d more text content entries...'
                % len(filtered_tc_records))
    self.copy_into_db(db, 'text_content', filtered_tc_records, self.tc_cols)
    gatherer.add('content', len(filtered_tc_records))
    return {
        'filtered_tr_records': filtered_tr_records,
        'flawed_tr_records': flawed_tr_records,
        'mod_tc_data': mod_tc_data,
        'filtered_tc_records': filtered_tc_records
    }
def download_statements(hashes, batch_size=200, ev_limit=10):
    """Download the INDRA Statements corresponding to a set of hashes.

    Parameters
    ----------
    hashes : sized collection
        Statement hashes to look up.
    batch_size : int
        Number of hashes sent per request. Default: 200.
    ev_limit : int
        Passed as ``ev_limit`` to ``get_statements_by_hash``. Default: 10.

    Returns
    -------
    dict
        Statements keyed by ``stmt.get_hash()``.
    """
    stmts_by_hash = {}
    # Ceiling division so the final partial batch counts toward the total.
    num_batches = (len(hashes) + batch_size - 1) // batch_size
    for group in tqdm.tqdm(batch_iter(hashes, batch_size), total=num_batches):
        idbp = indra_db_rest.get_statements_by_hash(list(group),
                                                    ev_limit=ev_limit)
        for stmt in idbp.statements:
            stmts_by_hash[stmt.get_hash()] = stmt
    return stmts_by_hash
def download_statements(df, ev_limit=5):
    """Download the INDRA Statements whose hashes appear in a data frame.

    Parameters
    ----------
    df : pandas.DataFrame
        Data frame whose ``df.hash`` values are statement hashes; they are
        requested in batches of 500.
    ev_limit : int
        Passed through to ``get_statements_by_hash``. Default: 5.

    Returns
    -------
    list
        All statements returned, in batch order.
    """
    statements = []
    hash_batches = batch_iter(df.hash, 500)
    for batch_num, hash_batch in enumerate(hash_batches):
        logger.info('Getting statement batch %d' % batch_num)
        processor = indra_db_rest.get_statements_by_hash(
            list(hash_batch), ev_limit=ev_limit)
        statements.extend(processor.statements)
    return statements
def get_raw_statements_for_pmids(pmids, mode='all', batch_size=100):
    """Return INDRA Statements based on extractions from given PMIDs.

    Parameters
    ----------
    pmids : set or list of str
        A set of PMIDs to find raw INDRA Statements for in the INDRA DB.
    mode : 'all' or 'distilled'
        The 'distilled' mode makes sure that the "best", non-redundant
        set of raw statements are found across potentially redundant text
        contents and reader versions. The 'all' mode doesn't do such
        distillation but is significantly faster. Any value other than
        'distilled' behaves like 'all'.
    batch_size : Optional[int]
        Determines how many PMIDs to fetch statements for in each
        iteration. Default: 100.

    Returns
    -------
    dict
        A dict keyed by PMID whose values are lists of the INDRA Statements
        obtained from that PMID.
    """
    db = get_db('primary')
    logger.info(f'Getting raw statements for {len(pmids)} PMIDs')
    all_stmts = defaultdict(list)
    # Ceiling division so the progress bar total is a whole batch count.
    num_batches = (len(pmids) + batch_size - 1) // batch_size
    for pmid_batch in tqdm.tqdm(batch_iter(pmids, return_func=set,
                                           batch_size=batch_size),
                                total=num_batches):
        if mode == 'distilled':
            clauses = [
                db.TextRef.pmid.in_(pmid_batch),
                db.TextContent.text_ref_id == db.TextRef.id,
                db.Reading.text_content_id == db.TextContent.id,
                db.RawStatements.reading_id == db.Reading.id
            ]
            distilled_stmts = distill_stmts(db, get_full_stmts=True,
                                            clauses=clauses)
            for stmt in distilled_stmts:
                all_stmts[stmt.evidence[0].pmid].append(stmt)
        else:
            id_stmts = \
                get_raw_stmt_jsons_from_papers(pmid_batch, id_type='pmid',
                                               db=db)
            for pmid, stmt_jsons in id_stmts.items():
                all_stmts[pmid] += stmts_from_json(stmt_jsons)
    return dict(all_stmts)
def _raw_sid_stmt_iter(self, db, id_set, do_enumerate=False):
    """Yield batches of (raw statement id, Statement) pairs for `id_set`.

    Ids are looked up in groups of ``self.batch_size``. When `do_enumerate`
    is True, each yielded item is ``(batch_index, pairs)``; otherwise it is
    just the list of pairs.
    """
    id_batches = batch_iter(id_set, self.batch_size)
    for batch_num, id_batch in enumerate(id_batches):
        rows = db.select_all(
            [db.RawStatements.id, db.RawStatements.json],
            db.RawStatements.id.in_(id_batch),
            yield_per=self.batch_size // 10)
        pairs = [(sid, _stmt_from_json(raw_json)) for sid, raw_json in rows]
        if not do_enumerate:
            yield pairs
            continue
        yield batch_num, pairs
def _process_pa_statement_res_wev(db, stmt_iterable, count=1000,
                                  fix_refs=True):
    """Build PA Statements, with evidence, from (PA, raw) DB object pairs.

    Issues a DeprecationWarning on every call.

    Parameters
    ----------
    db : DatabaseManager
        Passed to ``get_raw_stmts_frm_db_list``.
    stmt_iterable : iterable of (pa_stmt_db_obj, raw_stmt_db_obj)
        Pairs consumed in batches of `count`.
    count : int
        Batch size. Default: 1000.
    fix_refs : bool
        Passed to ``get_raw_stmts_frm_db_list``. Default: True.

    Returns
    -------
    dict
        PA Statements keyed by mk_hash; each one's evidence list is the first
        evidence of every raw statement paired with it, in input order.
    """
    warnings.warn(('This module is being taken out of service, as the tools '
                   'have become deprecated. Moreover, the service has been '
                   're-implemented to use newer tools as best as possible, '
                   'but some results may be subtly different.'),
                  DeprecationWarning)

    pa_by_hash = {}
    raw_ids_by_hash = {}
    raw_by_sid = {}
    n_evidence = 0
    for pair_batch in batch_iter(stmt_iterable, count):
        # Create each PA statement once and remember which raw statement
        # ids support it.
        batch_raw_objs = []
        for pa_obj, raw_obj in pair_batch:
            mk_hash = pa_obj.mk_hash
            if mk_hash in pa_by_hash:
                raw_ids_by_hash[mk_hash].append(raw_obj.id)
            else:
                pa_by_hash[mk_hash] = get_statement_object(pa_obj)
                raw_ids_by_hash[mk_hash] = [raw_obj.id]
            batch_raw_objs.append(raw_obj)
            n_evidence += 1
        logger.info("Up to %d pa statements, with %d pieces of "
                    "evidence in all." % (len(pa_by_hash), n_evidence))

        # Turn this batch's raw DB objects into Statements, keyed by id.
        sid_stmt_pairs = get_raw_stmts_frm_db_list(db, batch_raw_objs,
                                                   fix_refs, with_sids=True)
        for sid, raw_stmt in sid_stmt_pairs:
            raw_by_sid[sid] = raw_stmt
        logger.info("Processed %d raw statements." % len(sid_stmt_pairs))

    logger.info("Inserting evidence.")
    for mk_hash, raw_ids in raw_ids_by_hash.items():
        pa_by_hash[mk_hash].evidence = [raw_by_sid[sid].evidence[0]
                                        for sid in raw_ids]
    return pa_by_hash
def get_mutated_genes(self):
    """Compute per-gene mutation counts across TCGA studies.

    When ``self.mutation_cache`` is set, counts are loaded from that JSON
    file. Otherwise they are fetched from the cBio web service for every
    study in ``tcga_studies[self.tcga_study_prefix]``, querying the keys of
    ``hgnc_ids`` in batches of 200 and counting one mutation per
    (patient, gene) entry whose value is not None.

    Each count is then divided by the length reported by
    ``uniprot_client.get_length``; 500 is used when no UniProt ID or no
    length is available.

    Returns
    -------
    tuple of (dict, dict)
        ``(self.mutations, self.norm_mutations)``: raw and length-normalized
        counts, both keyed by gene name.
    """
    if self.mutation_cache:
        logger.info('Loading mutations from %s' % self.mutation_cache)
        with open(self.mutation_cache, 'r') as fh:
            self.mutations = json.load(fh)
    else:
        logger.info('Getting mutations from cBio web service')
        counts = {}
        for study_name in tcga_studies[self.tcga_study_prefix]:
            gene_batches = enumerate(batch_iter(hgnc_ids.keys(), 200))
            for batch_num, gene_batch in gene_batches:
                logger.info('Fetching mutations for %s and gene batch %s'
                            % (study_name, batch_num))
                by_patient = cbio_client.get_profile_data(
                    study_name, gene_batch, 'mutation')
                # e.g. 'ICGC_0002_TD': {'BRAF': None, 'KRAS': 'G12D'}
                for gene_mut_dict in by_patient.values():
                    for gene, mutation in gene_mut_dict.items():
                        if mutation is None:
                            continue
                        counts[gene] = counts.get(gene, 0) + 1
        self.mutations = counts

    default_length = 500  # a guess at a default
    self.norm_mutations = {}
    for gene_name, num_muts in self.mutations.items():
        hgnc_id = get_hgnc_id(gene_name)
        up_id = get_uniprot_id(hgnc_id)
        length = None
        if up_id:
            length = uniprot_client.get_length(up_id)
            if not length:
                logger.warning("Could not get length for Uniprot "
                               "ID %s" % up_id)
        else:
            logger.warning("Could not get Uniprot ID for HGNC symbol %s "
                           "with HGNC ID %s" % (gene_name, hgnc_id))
        self.norm_mutations[gene_name] = \
            num_muts / float(length or default_length)
    return self.mutations, self.norm_mutations
def _process_pa_statement_res_nev(stmt_iterable, count=1000):
    """Build PA Statements (without evidence) from PA statement DB objects.

    Issues a DeprecationWarning on every call. Objects are consumed in
    batches of `count`; the first object seen for each mk_hash wins.

    Returns
    -------
    dict
        Statements keyed by mk_hash.
    """
    warnings.warn(('This module is being taken out of service, as the tools '
                   'have become deprecated. Moreover, the service has been '
                   're-implemented to use newer tools as best as possible, '
                   'but some results may be subtly different.'),
                  DeprecationWarning)

    pa_by_hash = {}
    for pa_batch in batch_iter(stmt_iterable, count):
        for pa_obj in pa_batch:
            mk_hash = pa_obj.mk_hash
            if mk_hash in pa_by_hash:
                continue
            pa_by_hash[mk_hash] = get_statement_object(pa_obj)
        logger.info("Up to %d pa statements in all." % len(pa_by_hash))
    return pa_by_hash
def download_statements(df, beliefs, ev_limit=5):
    """Download the INDRA Statements for a data frame and set their beliefs.

    The unique values of ``df.stmt_hash`` are requested in batches of 500.
    Each returned statement gets ``belief`` set from ``beliefs`` (looked up
    by ``stmt.get_hash()``); statements without an entry are logged and left
    unchanged.

    Returns
    -------
    list
        All statements returned.
    """
    unique_hashes = list(set(df.stmt_hash))
    hash_batches = list(batch_iter(unique_hashes, 500))
    logger.info('Getting %d unique hashes from db' % len(unique_hashes))

    all_stmts = []
    for hash_group in tqdm.tqdm(hash_batches):
        processor = indra_db_rest.get_statements_by_hash(list(hash_group),
                                                         ev_limit=ev_limit)
        all_stmts.extend(processor.statements)

    for stmt in all_stmts:
        belief = beliefs.get(stmt.get_hash())
        if belief is not None:
            stmt.belief = belief
        else:
            logger.info('No belief found for %s' % str(stmt))
    return all_stmts
def select_all_batched(self, batch_size, tbls, *args, skip_idx=None,
                       order_by=None):
    """Load the results of a query in batches of size batch_size.

    Unlike plain ``yield_per``, which returns one flat iterable, this yields
    ``(batch_index, batch)`` pairs, one per batch. The batch whose index
    equals `skip_idx` (if any) is not yielded.

    Note that the order of results, and therefore which rows end up in which
    batch, may vary between runs for large queries unless an explicit
    `order_by` clause is given.
    """
    query = self.filter_query(tbls, *args)
    if order_by:
        query = query.order_by(order_by)
    rows = query.yield_per(batch_size)
    for idx, batch in enumerate(batch_iter(rows, batch_size)):
        if idx == skip_idx:
            continue
        yield idx, batch
def get_stmts_pmids_mesh(subject, stmt_type, object_list):
    """Get statements, their evidence PMIDs, and MeSH terms for those PMIDs.

    Parameters
    ----------
    subject : str
        Subject passed to ``idr.get_statements``.
    stmt_type : str
        Statement type passed to ``idr.get_statements``.
    object_list : iterable of str
        One statement query is made per object.

    Returns
    -------
    tuple
        ``(stmts, pmids, mesh_terms)``: all statements found, the PMID of
        every evidence (in order, duplicates and None included), and the MeSH
        terms from the metadata of each distinct, non-None PMID.
    """
    stmts = []
    for obj in object_list:
        idrp = idr.get_statements(subject=subject, object=obj,
                                  stmt_type=stmt_type, ev_limit=10000)
        stmts += idrp.statements
    # Collect the PMIDs for the stmts
    pmids = [e.pmid for s in stmts for e in s.evidence]

    # Request metadata once per distinct PMID and skip missing ones.
    # Otherwise a repeated PMID contributes its MeSH terms once per batch it
    # happens to fall into, which makes the result depend on batching.
    unique_pmids = [pmid for pmid in dict.fromkeys(pmids) if pmid is not None]
    mesh_terms = []
    for batch in batch_iter(unique_pmids, 200):
        pmid_list = list(batch)
        print("Retrieving metadata for %d articles" % len(pmid_list))
        metadata = get_metadata_for_ids(pmid_list)
        for pmid_meta in metadata.values():
            mesh_terms += [d['mesh'] for d in pmid_meta['mesh_annotations']]
    return (stmts, pmids, mesh_terms)
def get_mutated_genes(self):
    """Compute raw and normalized per-gene mutation counts for TCGA studies.

    When ``self.mutation_cache`` is set, counts are loaded from that JSON
    file. Otherwise they are fetched from the cBio web service for every
    study in ``tcga_studies[self.tcga_study_prefix]``, querying the keys of
    ``hgnc_ids`` in batches of 200 and counting one mutation per
    (patient, gene) entry whose value is not None. Normalized values come
    from ``self.normalize_mutation_count``.

    Returns
    -------
    tuple of (dict, dict)
        ``(self.mutations, self.norm_mutations)``, both keyed by gene name.
    """
    if not self.mutation_cache:
        logger.info('Getting mutations from cBio web service')
        counts = {}
        for study_name in tcga_studies[self.tcga_study_prefix]:
            gene_batches = enumerate(batch_iter(hgnc_ids.keys(), 200))
            for batch_num, gene_batch in gene_batches:
                logger.info('Fetching mutations for %s and gene batch %s'
                            % (study_name, batch_num))
                by_patient = cbio_client.get_profile_data(
                    study_name, gene_batch, 'mutation')
                # e.g. 'ICGC_0002_TD': {'BRAF': None, 'KRAS': 'G12D'}
                for gene_mut_dict in by_patient.values():
                    for gene, mutation in gene_mut_dict.items():
                        if mutation is not None:
                            counts[gene] = counts.get(gene, 0) + 1
        self.mutations = counts
    else:
        logger.info('Loading mutations from %s' % self.mutation_cache)
        with open(self.mutation_cache, 'r') as fh:
            self.mutations = json.load(fh)

    # Normalize mutations by length
    self.norm_mutations = {}
    for gene, num_muts in self.mutations.items():
        normalized = self.normalize_mutation_count(gene, num_muts)
        self.norm_mutations[gene] = normalized
    return self.mutations, self.norm_mutations
def get_statements(clauses, count=1000, do_stmt_count=True, db=None,
                   preassembled=True, with_support=False):
    """Select statements according to a given set of clauses.

    Parameters
    ----------
    clauses : list
        list of sqlalchemy WHERE clauses to pass to the filter query. The
        list passed in is not modified.
    count : int
        Number of statements to retrieve and process in each batch.
    do_stmt_count : bool
        Whether or not to perform an initial statement counting step to give
        more meaningful progress messages. Only used when `preassembled` is
        False.
    db : :py:class:`DatabaseManager`
        Optionally specify a database manager that attaches to something
        besides the primary database, for example a local database instance.
    preassembled : bool
        If true, statements will be selected from the table of pre-assembled
        statements. Otherwise, they will be selected from the raw statements.
        Default is True.
    with_support : bool
        If True (and `preassembled` is True), fill in the ``supports`` and
        ``supported_by`` fields of the returned statements from the PA
        support links table. Default is False.

    Returns
    -------
    list of Statements from the database corresponding to the query.
    """
    if db is None:
        db = get_primary_db()

    stmts_tblname = 'pa_statements' if preassembled else 'raw_statements'

    if not preassembled:
        stmts = []
        q = db.filter_query(stmts_tblname, *clauses)
        if do_stmt_count:
            logger.info("Counting statements...")
            num_stmts = q.count()
            logger.info("Total of %d statements" % num_stmts)
        db_stmts = q.yield_per(count)
        for subset in batch_iter(db_stmts, count):
            stmts.extend(make_raw_stmts_from_db_list(db, subset))
            if do_stmt_count:
                logger.info("%d of %d statements" % (len(stmts), num_stmts))
            else:
                logger.info("%d statements" % len(stmts))
    else:
        logger.info("Getting preassembled statements.")
        # Get pairs of pa statements with their supporting statement (as long
        # as the number of supporting statements). Copy the clauses first so
        # the caller's list is not extended with the join conditions.
        clauses = list(clauses)
        clauses.extend(db.join(db.PAStatements, db.RawStatements))
        pa_raw_stmt_pairs = \
            db.select_all([db.PAStatements, db.RawStatements], *clauses,
                          yield_per=count)

        # Iterate over the batches to create the statement objects.
        stmt_dict = {}
        ev_dict = {}
        raw_stmt_dict = {}
        total_ev = 0
        for stmt_pair_batch in batch_iter(pa_raw_stmt_pairs, count):
            # Instantiate the PA statement objects, and record the uuid
            # evidence (raw statement) links.
            raw_stmt_obj_list = []
            for pa_stmt_db_obj, raw_stmt_db_obj in stmt_pair_batch:
                k = pa_stmt_db_obj.mk_hash
                if k not in stmt_dict:
                    stmt_dict[k] = _get_statement_object(pa_stmt_db_obj)
                    ev_dict[k] = [raw_stmt_db_obj.uuid]
                else:
                    ev_dict[k].append(raw_stmt_db_obj.uuid)
                raw_stmt_obj_list.append(raw_stmt_db_obj)
                total_ev += 1
            # Count evidence links; len(ev_dict) would just repeat the
            # number of PA statements.
            logger.info("Up to %d pa statements, with %d pieces of evidence "
                        "in all." % (len(stmt_dict), total_ev))

            # Instantiate the raw statements.
            raw_stmts = make_raw_stmts_from_db_list(db, raw_stmt_obj_list)
            raw_stmt_dict.update({s.uuid: s for s in raw_stmts})
            logger.info("Processed %d raw statements." % len(raw_stmts))

        # Attach the evidence
        logger.info("Inserting evidence.")
        for k, uuid_list in ev_dict.items():
            stmt_dict[k].evidence = [raw_stmt_dict[uuid].evidence[0]
                                     for uuid in uuid_list]

        # Populate the supports/supported by fields.
        if with_support:
            logger.info("Populating support links.")
            support_links = db.filter_query(
                [db.PASupportLinks.supported_mk_hash,
                 db.PASupportLinks.supporting_mk_hash],
                or_(db.PASupportLinks.supported_mk_hash.in_(stmt_dict.keys()),
                    db.PASupportLinks.supporting_mk_hash.in_(stmt_dict.keys()))
            ).distinct().yield_per(count)
            for supped_hash, supping_hash in support_links:
                supped_stmt = stmt_dict.get(supped_hash,
                                            Unresolved(shallow_hash=supped_hash))
                supping_stmt = stmt_dict.get(
                    supping_hash, Unresolved(shallow_hash=supping_hash))
                if not isinstance(supped_stmt, str):
                    supped_stmt.supported_by.append(supping_stmt)
                if not isinstance(supping_stmt, str):
                    supping_stmt.supports.append(supped_stmt)

        stmts = list(stmt_dict.values())
        logger.info("In all, there are %d pa statements." % len(stmts))

    return stmts
ambigs = get_filter_ambigs(str_counts) all_ambigs_pmids = get_all_pmids(ambigs) # Construct list of ambiguities #print('Found a total of %d ambiguities.' % len(ambigs)) # Learn models param_grid = { 'C': [10.0], 'max_features': [100, 1000], 'ngram_range': [(1, 2)] } models = {} if args.nproc == 1: for idx, ambig_terms_batch in \ enumerate(batch_iter(all_ambigs_pmids, 10)): pickle_name = 'gilda_ambiguities_hgnc_mesh_%d.pkl' % idx models = learn_batch(ambig_terms_batch) with open(pickle_name, 'wb') as fh: pickle.dump(models, fh) else: pool = Pool(args.nproc) fun = functools.partial(learn_model, params=param_grid) pkl_idx = 0 models = {} for count, model in enumerate( pool.imap_unordered(fun, all_ambigs_pmids, chunksize=10)): print('#### %d ####' % count) if model is None: print('Model is None, skipping') else:
def decorator(*args, **kwargs):
    """Run the wrapped query function `f` on the request and build a Response.

    Control parameters (offset, ev_limit, best_first, stream, max_stmts,
    format, api_key) are popped from the request args before the remaining
    query is passed to `f`. Elsevier evidence text is truncated for callers
    without Elsevier access. The result is rendered as HTML when
    ``format=html`` and as JSON otherwise, optionally streamed in chunks.
    """
    tracker = LogTracker()
    start_time = datetime.now()
    logger.info("Got query for %s at %s!" % (f.__name__, start_time))

    query = request.args.copy()
    offs = query.pop('offset', None)
    ev_lim = query.pop('ev_limit', None)
    # Only the string 'true' (any case) enables best-first ordering, so that
    # e.g. best_first=false is actually honored.
    best_first_str = query.pop('best_first', 'true')
    best_first = best_first_str.lower() == 'true'
    do_stream_str = query.pop('stream', 'false')
    do_stream = do_stream_str == 'true'
    max_stmts = min(int(query.pop('max_stmts', MAX_STATEMENTS)),
                    MAX_STATEMENTS)
    fmt = query.pop('format', 'json')
    api_key = query.pop('api_key', None)

    logger.info("Running function %s after %s seconds."
                % (f.__name__, sec_since(start_time)))
    result = f(query, offs, max_stmts, ev_lim, best_first, *args, **kwargs)
    logger.info("Finished function %s after %s seconds."
                % (f.__name__, sec_since(start_time)))

    # Redact elsevier content for those without permission.
    if api_key is None or not _has_elsevier_auth(api_key):
        for stmt_json in result['statements'].values():
            for ev_json in stmt_json['evidence']:
                if get_source(ev_json) == 'elsevier':
                    text = ev_json['text']
                    if len(text) > 200:
                        ev_json['text'] = text[:200] + REDACT_MESSAGE
    logger.info("Finished redacting evidence for %s after %s seconds."
                % (f.__name__, sec_since(start_time)))

    result['offset'] = offs
    result['evidence_limit'] = ev_lim
    result['statement_limit'] = MAX_STATEMENTS
    result['statements_returned'] = len(result['statements'])

    if fmt == 'html':
        stmts_json = result.pop('statements')
        ev_totals = result.pop('evidence_totals')
        stmts = stmts_from_json(stmts_json.values())
        html_assembler = HtmlAssembler(stmts, result, ev_totals,
                                       title='INDRA DB REST Results',
                                       db_rest_url=request.url_root[:-1])
        content = html_assembler.make_model()
        if tracker.get_messages():
            level_stats = ['%d %ss' % (n, lvl.lower())
                           for lvl, n in tracker.get_level_stats().items()]
            msg = ' '.join(level_stats)
            content = html_assembler.append_warning(msg)
        mimetype = 'text/html'
    else:
        # Return JSON for all other values of the format argument
        result.update(tracker.get_level_stats())
        content = json.dumps(result)
        mimetype = 'application/json'

    if do_stream:
        # Stream fixed-size slices of the serialized content so each item
        # is a string. Batching a string item-by-item would produce groups
        # of single characters, not string chunks.
        chunk_size = 10000
        gen = (content[i:i + chunk_size]
               for i in range(0, len(content), chunk_size))
        resp = Response(gen, mimetype=mimetype)
    else:
        resp = Response(content, mimetype=mimetype)

    # Measure the serialized content rather than resp.data: reading .data
    # from a streamed response would buffer the whole stream.
    logger.info("Exiting with %d statements with %d evidence of size "
                "%f MB after %s seconds."
                % (result['statements_returned'], result['total_evidence'],
                   sys.getsizeof(content) / 1e6, sec_since(start_time)))
    return resp
def distill_stmts(db, get_full_stmts=False, clauses=None, num_procs=1,
                  delete_duplicates=True, weed_evidence=True, batch_size=1000):
    """Get a corpus of statements from clauses and filters duplicate evidence.

    Parameters
    ----------
    db : :py:class:`DatabaseManager`
        A database manager instance to access the database.
    get_full_stmts : bool
        By default (False), only Statement ids (the primary index of Statements
        on the database) are returned. However, if set to True, serialized
        INDRA Statements will be returned. Note that this will in general be
        VERY large in memory, and therefore should be used with caution.
    clauses : None or list of sqlalchemy clauses
        By default None. Specify sqlalchemy clauses to reduce the scope of
        statements, e.g. `clauses=[db.Statements.type == 'Phosphorylation']` or
        `clauses=[db.Statements.uuid.in_([<uuids>])]`.
    num_procs : int
        Select the number of processes that can be used (passed to
        `_get_filtered_db_statements`).
    delete_duplicates : bool
        Choose whether you want to delete the statements that are found to be
        duplicates. When True, raw statement ids already linked in
        RawUniqueLinks are also loaded and taken into account while filtering;
        when False, no linked ids are loaded.
    weed_evidence : bool
        If True, evidence links that exist for raw statements that now have
        better alternatives will be removed. If False, such links will remain,
        which may cause problems in incremental pre-assembly.
    batch_size : int
        Number of duplicate raw statement ids handled per batch when deleting
        duplicates (only used when `delete_duplicates` is True). Default 1000.

    Returns
    -------
    stmt_ret : set
        A set of either statement ids or serialized statements, depending on
        `get_full_stmts`.
    """
    # Raw statement ids that already have a link in RawUniqueLinks.
    if delete_duplicates:
        logger.info("Looking for ids from existing links...")
        linked_sids = {sid for sid,
                       in db.select_all(db.RawUniqueLinks.raw_stmt_id)}
    else:
        linked_sids = set()

    # Get de-duplicated Statements, and duplicate uuids, as well as uuid of
    # Statements that have been improved upon...
    logger.info("Sorting reading statements...")
    stmt_nd = _get_reading_statement_dict(db, clauses, get_full_stmts)

    stmts, duplicate_sids, bettered_duplicate_sids = \
        _get_filtered_rdg_statements(stmt_nd, get_full_stmts, linked_sids)
    logger.info("After filtering reading: %d unique statements, %d exact "
                "duplicates, %d with results from better resources available."
                % (len(stmts), len(duplicate_sids),
                   len(bettered_duplicate_sids)))
    # Sanity check: no id selected for deletion may already be linked.
    assert not linked_sids & duplicate_sids, linked_sids & duplicate_sids
    del stmt_nd  # This takes up a lot of memory, and is done being used.

    # Merge in the results of filtering the database (non-reading) stmts.
    db_stmts, db_duplicates = \
        _get_filtered_db_statements(db, get_full_stmts, clauses, linked_sids,
                                    num_procs)
    stmts |= db_stmts
    duplicate_sids |= db_duplicates
    logger.info("After filtering database statements: %d unique, %d duplicates."
                % (len(stmts), len(duplicate_sids)))
    assert not linked_sids & duplicate_sids, linked_sids & duplicate_sids

    # Remove support links for statements that have better versions available.
    bad_link_sids = bettered_duplicate_sids & linked_sids
    if len(bad_link_sids) and weed_evidence:
        logger.info("Removing bettered evidence links...")
        rm_links = db.select_all(
            db.RawUniqueLinks,
            db.RawUniqueLinks.raw_stmt_id.in_(bad_link_sids)
        )
        db.delete_all(rm_links)

    # Delete exact duplicates
    if len(duplicate_sids) and delete_duplicates:
        logger.info("Deleting duplicates...")
        for dup_id_batch in batch_iter(duplicate_sids, batch_size, set):
            bad_stmts = db.select_all(db.RawStatements,
                                      db.RawStatements.id.in_(dup_id_batch))
            bad_sid_set = {s.id for s in bad_stmts}
            # Agents are deleted before their statements -- presumably
            # because raw_agents rows reference raw statements (confirm).
            bad_agents = db.select_all(db.RawAgents,
                                       db.RawAgents.stmt_id.in_(bad_sid_set))
            logger.info("Deleting %d agents associated with redundant raw "
                        "statements." % len(bad_agents))
            db.delete_all(bad_agents)

            logger.info("Deleting %d redundant raw statements."
                        % len(bad_stmts))
            db.delete_all(bad_stmts)

    return stmts
def match_correlations(corr_z: pd.DataFrame,
                       sd_range: Tuple[float, Union[float, None]],
                       script_settings: Dict[str, Union[str, int, float]],
                       graph_filepath: str,
                       z_corr_filepath: str,
                       apriori_explained: Optional[Dict[str, str]] = None,
                       graph_type: str = 'unsigned',
                       allowed_ns: Optional[List[str]] = None,
                       allowed_sources: Optional[List[str]] = None,
                       is_a_part_of: Optional[List[str]] = None,
                       expl_funcs: Optional[List[str]] = None,
                       reactome_filepath: Optional[str] = None,
                       indra_date: Optional[str] = None,
                       info: Optional[Dict[str, Any]] = None,
                       depmap_date: Optional[str] = None,
                       n_chunks: Optional[int] = 8,
                       immediate_only: Optional[bool] = False,
                       return_unexplained: Optional[bool] = False,
                       reactome_dict: Optional[Dict[str, Any]] = None,
                       subset_list: Optional[List[Union[str, int]]] = None):
    """The main loop for matching correlations with INDRA explanations

    Note that indranet is assumed to be a global variable that needs to be
    set outside of this function and be set to global; it is not a
    parameter. It is the graph representation of the indra network: each
    edge should have an attribute named 'statements' containing a list of
    sources supporting that edge. If signed search, indranet is expected to
    be an nx.MultiDiGraph with edges keyed by (gene, gene, sign) tuples.

    Work is split into chunks that are processed by a multiprocessing pool;
    results are collected through the global ``output_list`` (filled by
    ``success_callback``) and merged into a DepMapExplainer.

    Parameters
    ----------
    corr_z : pd.DataFrame
        The pre-processed correlation matrix. No more processing of the
        matrix should have to be done here, i.e. it should already have
        filtered the correlations to the proper SD ranges and removed the
        genes that are not applicable for this explanation, self
        correlations should also have been removed.
    sd_range : Tuple[float, Union[float, None]]
        The SD ranges that the corr_z is filtered to; stored in the
        explainer info.
    script_settings : Dict[str, Union[str, int, float]]
        Dictionary with script settings for the purpose of book keeping
    graph_filepath : str
        File path to the graph used
    z_corr_filepath : str
        File path to the correlation matrix used
    apriori_explained : Optional[Dict[str, str]]
        Passed through to the worker function; defaults to an empty dict.
    graph_type : str
        The type of graph used. Default: 'unsigned'.
    allowed_ns : Optional[List[str]]
        If given, only these namespaces (compared lower-cased) are allowed.
    allowed_sources : Optional[List[str]]
        If given, only these sources (compared lower-cased) are allowed.
    is_a_part_of : Optional[List[str]]
        Passed through to the worker function; defaults to an empty list.
    expl_funcs : Optional[List[str]]
        Names of the explanation functions to use. Unknown names are logged
        and ignored. If not given, all registered functions are used.
    reactome_filepath : Optional[str]
        File path to the reactome data
    indra_date : Optional[str]
        Date of the INDRA network; defaults to today as YYYYMMDD.
    info : Optional[Dict[str, Any]]
        Extra entries merged into the explainer's info dict.
    depmap_date : Optional[str]
        Date of the DepMap data; defaults to today as YYYYMMDD.
    n_chunks : Optional[int]
        Requested number of chunks to submit to the pool, capped at 512.
        None is treated as the default of 8.
    immediate_only : Optional[bool]
        Passed through to the worker function. Default: False.
    return_unexplained : Optional[bool]
        Passed through to the worker function. Default: False.
    reactome_dict : Optional[Dict[str, Any]]
        Passed through to the worker function.
    subset_list : Optional[List[Union[str, int]]]
        Passed to ``get_pairs`` and ``corr_matrix_to_generator``; presumably
        restricts the checked pairs to those involving the listed entries
        -- TODO confirm.

    Returns
    -------
    depmap_analysis.util.statistics.DepMapExplainer
        An instance of the DepMapExplainer class containing the explanations
        for the correlations.

    Raises
    ------
    ValueError
        If no registered explanation function is selected.
    """
    # Map each expl type to a function that handles that explanation
    if not expl_funcs:
        # No function names provided, use all explanation functions
        logger.info('All explanation types used')
        expl_types = {
            funcname_to_colname[func_name]: func
            for func_name, func in expl_functions.items()
        }
    else:
        # Map function names to functions, skipping unregistered names
        expl_types = {}
        for func_name in expl_funcs:
            if func_name not in expl_functions:
                logger.warning(f'{func_name} does not map to a registered '
                               f'explanation function. Allowed functions '
                               f'{", ".join(expl_functions.keys())}')
            else:
                expl_types[funcname_to_colname[func_name]] = \
                    expl_functions[func_name]

    if not len(expl_types):
        raise ValueError('No explanation functions provided')

    bool_columns = ('not_in_graph', 'explained') + tuple(expl_types.keys())
    stats_columns = id_columns + bool_columns
    expl_cols = expl_columns

    apriori_explained = apriori_explained or {}

    logger.info(f'Doing correlation matching with {graph_type} graph')

    # Get options
    if allowed_ns is not None:
        allowed_ns = {n.lower() for n in allowed_ns}
        logger.info('Only allowing the following namespaces: %s' %
                    ', '.join(allowed_ns))
    if allowed_sources is not None:
        allowed_sources = {s.lower() for s in allowed_sources}
        logger.info('Only allowing the following sources: %s' %
                    ', '.join(allowed_sources))
    is_a_part_of = is_a_part_of or []
    if n_chunks is None:
        n_chunks = 8

    # Try to get dates of files from file names and file info
    ymd_now = datetime.now().strftime('%Y%m%d')
    indra_date = indra_date or ymd_now
    depmap_date = depmap_date or ymd_now

    logger.info('Calculating number of pairs to check...')
    estim_pairs = get_pairs(corr_z, subset_list=subset_list)
    logger.info(f'Starting workers at {datetime.now().strftime("%H:%M:%S")} '
                f'with about {estim_pairs} pairs to check')
    tstart = time()

    with mp.Pool() as pool:
        MAX_SUB = 512
        n_sub = min(n_chunks, MAX_SUB)
        chunksize = get_chunk_size(n_sub, estim_pairs)

        # Pick one more so we don't do more than MAX_SUB
        chunksize += 1 if n_sub == MAX_SUB else 0
        chunk_iter = batch_iter(
            iterator=corr_matrix_to_generator(corr_z,
                                              subset_list=subset_list),
            batch_size=chunksize,
            return_func=list)
        for chunk in chunk_iter:
            pool.apply_async(
                func=_match_correlation_body,
                # args should match the args for func
                args=(chunk, expl_types, stats_columns, expl_cols,
                      bool_columns, graph_type, apriori_explained,
                      allowed_ns, allowed_sources, is_a_part_of,
                      immediate_only, return_unexplained, reactome_dict),
                callback=success_callback,
                error_callback=error_callback)

        logger.info('Done submitting work to pool workers')
        pool.close()
        pool.join()

    print(f'Execution time: {time() - tstart} seconds')
    print(f'Done at {datetime.now().strftime("%H:%M:%S")}')

    # Here initialize a DepMapExplainer and append the result for the
    # different processes
    explainer = DepMapExplainer(stats_columns=stats_columns,
                                expl_columns=expl_columns,
                                graph_filepath=graph_filepath,
                                z_corr_filepath=z_corr_filepath,
                                reactome_filepath=reactome_filepath,
                                info={
                                    'indra_network_date': indra_date,
                                    'depmap_date': depmap_date,
                                    'sd_range': sd_range,
                                    'graph_type': graph_type,
                                    **(info or {})
                                },
                                script_settings=script_settings)

    logger.info(f'Generating DepMapExplainer with output from '
                f'{len(output_list)} results')
    # Concatenate all worker outputs in one call: DataFrame.append was
    # removed in pandas 2.0, and appending in a loop copies the growing
    # frame every iteration.
    if output_list:
        stats_frames = [pd.DataFrame(data=stats_dict)
                        for stats_dict, _ in output_list]
        expl_frames = [pd.DataFrame(data=expl_dict)
                       for _, expl_dict in output_list]
        explainer.stats_df = pd.concat([explainer.stats_df, *stats_frames])
        explainer.expl_df = pd.concat([explainer.expl_df, *expl_frames])

    return explainer
def filter_text_content(self, db, tc_data):
    """Link text content entries to their TextRefs and drop existing ones.

    Parameters
    ----------
    db : DatabaseManager
        Database handle used to select TextRef and TextContent rows.
    tc_data : list of dict
        Text content entries. Each entry is read for the keys 'pmid',
        'pmcid', 'doi', 'source', 'format', 'text_type' and 'content'
        ('cord_uid' is read for entries that cannot be matched). The
        collection is iterated twice, so it must not be a one-shot iterator.

    Returns
    -------
    tuple
        ``(filtered_tc_records, flawed_tcs)``:

        - ``filtered_tc_records`` is a de-duplicated list of
          ``(trid, source, format, text_type, content)`` tuples for entries
          that matched exactly one TextRef and whose
          ``(trid, source, format, text_type)`` is not already present in
          the TextContent table (only existing rows with format 'text' are
          considered).
        - ``flawed_tcs`` is a set of ``(pmid, pmcid, doi, cord_uid, trids)``
          tuples for entries that matched no TextRef or more than one.

        An empty ``tc_data`` yields ``([], set())`` so callers can always
        unpack two values.
    """
    if len(tc_data) == 0:
        return [], set()
    logger.info("Beginning to filter text content...")

    # Step 1: Collect every TextRef that matches any pmid/pmcid/doi.
    tr_list = []
    for tc_batch in batch_iter(tc_data, 5000):
        # Materialize the batch once; it may be a one-shot iterable.
        ids = [(tc['pmid'], tc['pmcid'], tc['doi']) for tc in tc_batch]
        pmids, pmcids, dois = zip(*ids)
        # Remove any Nones and convert to sets
        pmid_set = {i for i in pmids if i is not None}
        pmcid_set = {i for i in pmcids if i is not None}
        doi_set = {i for i in dois if i is not None}
        # Get all TextRefs for the CORD19 IDs
        logger.debug("Getting text refs for CORD19 articles")
        tr_list += db.select_all(
            db.TextRef,
            sql_exp.or_(db.TextRef.pmid_in(pmid_set, filter_ids=True),
                        db.TextRef.pmcid_in(pmcid_set, filter_ids=True),
                        db.TextRef.doi_in(doi_set, filter_ids=True)))

    # Map each ID back to the TextRef objects that carry it so records in
    # tc_data can be linked to TextRefs.
    trs_by_doi = defaultdict(set)
    trs_by_pmc = defaultdict(set)
    trs_by_pmid = defaultdict(set)
    for tr in tr_list:
        if tr.doi:
            trs_by_doi[tr.doi].add(tr)
        if tr.pmcid:
            trs_by_pmc[tr.pmcid].add(tr)
        if tr.pmid:
            trs_by_pmid[tr.pmid].add(tr)

    # Re-key the text content by TextRef id instead of pmid/pmcid/doi.
    flawed_tcs = set()
    tc_data_by_tr = []
    for tc_entry in tc_data:
        by_tr_entry = {field: tc_entry[field]
                       for field in ('source', 'format', 'text_type',
                                     'content')}
        tr_ids_for_tc = set()
        for id_type, trs_by_id in (('pmid', trs_by_pmid),
                                   ('pmcid', trs_by_pmc),
                                   ('doi', trs_by_doi)):
            tr_set = trs_by_id.get(tc_entry[id_type])
            if tr_set is None:
                continue
            if len(tr_set) != 1:
                logger.warning(
                    '%s %s is associated with multiple TextRefs: %s'
                    % (id_type, tc_entry[id_type], tr_set))
                continue
            tr_ids_for_tc.add(next(iter(tr_set)).id)
        # Because this function is called using tc_data that has already
        # been filtered by text ref, we should always get unambiguous
        # matches to text_refs here.
        if len(tr_ids_for_tc) != 1:
            log_entry = (tc_entry['pmid'], tc_entry['pmcid'],
                         tc_entry['doi'], tc_entry['cord_uid'],
                         tuple(tr_ids_for_tc))
            logger.warning('Missing or ambiguous match to text ref: %s'
                           % str(log_entry))
            flawed_tcs.add(log_entry)
        else:
            by_tr_entry['trid'] = next(iter(tr_ids_for_tc))
            tc_data_by_tr.append(by_tr_entry)

    # Step 2: Get existing Text Content rows for the given text refs with
    # the same source and text type. A set keeps the membership test in
    # the final filter O(1) per record.
    existing_tc_records = set()
    for source, text_type in (('cord19_abstract', 'abstract'),
                              ('cord19_pmc_xml', 'fulltext'),
                              ('cord19_pdf', 'fulltext')):
        logger.debug('Finding existing text content from db for '
                     'source type %s' % source)
        trids = [tc['trid'] for tc in tc_data_by_tr
                 if tc['source'] == source]
        if not trids:
            # Nothing of this source to check; skip an empty IN query.
            continue
        existing_tcs = db.select_all(
            db.TextContent,
            db.TextContent.text_ref_id.in_(trids),
            db.TextContent.source == source,
            db.TextContent.format == 'text',
            db.TextContent.text_type == text_type)
        # Reformat Text Content objects to tuples
        source_records = {(tc.text_ref_id, tc.source, tc.format,
                           tc.text_type) for tc in existing_tcs}
        existing_tc_records |= source_records
        logger.debug("Found %d existing records on the db for %s."
                     % (len(source_records), source))

    # Build the output records, excluding those already in the database.
    filtered_tc_records = {
        (e['trid'], e['source'], e['format'], e['text_type'], e['content'])
        for e in tc_data_by_tr
        if (e['trid'], e['source'], e['format'], e['text_type'])
        not in existing_tc_records
    }
    logger.info("Finished filtering the text content.")
    return list(filtered_tc_records), flawed_tcs
def main():
    """Run a reading job over a slice of text content ids stored on S3.

    The function:

    - reads the newline-separated id list at ``args.s3_base + 'id_list'`` in
      ``bucket_name`` and keeps lines ``args.start_index:args.end_index`` in
      shuffled order;
    - records the versions of the requested readers on S3;
    - runs the readers, in batches of ``args.batch`` if that is given;
    - uploads any sparser and TRIPS log files found in the working
      directory.

    Exits with status 1 if the id list object does not exist. Reports an
    argparse usage error if the reading/result mode combination is not
    allowed.
    """
    arg_parser = get_parser()
    args = arg_parser.parse_args()

    # Reject nonsensical option combinations before touching S3 or
    # constructing readers. (An ``assert`` would vanish under ``python -O``.)
    forbidden_combos = [('all', 'unread'), ('none', 'unread'),
                        ('none', 'none')]
    if (args.read_mode, args.rslt_mode) in forbidden_combos:
        arg_parser.error("The combination of reading mode %s and statement "
                         "mode %s is not allowed."
                         % (args.read_mode, args.rslt_mode))

    s3 = boto3.client('s3')
    s3_log_prefix = get_s3_job_log_prefix(args.s3_base, args.job_name)
    logger.info("Using log prefix \"%s\"" % s3_log_prefix)
    id_list_key = args.s3_base + 'id_list'
    logger.info("Looking for id list on s3 at \"%s\"" % id_list_key)
    try:
        id_list_obj = s3.get_object(Bucket=bucket_name, Key=id_list_key)
    except botocore.exceptions.ClientError as e:
        # Handle a missing object gracefully
        if e.response['Error']['Code'] == 'NoSuchKey':
            logger.info('Could not find PMID list file at %s, exiting'
                        % id_list_key)
            sys.exit(1)
        # If there was some other kind of problem, re-raise the exception
        raise

    # Get the content from the object
    id_list_str = id_list_obj['Body'].read().decode('utf8').strip()
    id_str_list = id_list_str.splitlines()[args.start_index:args.end_index]
    random.shuffle(id_str_list)
    tcids = [int(line.strip()) for line in id_str_list]

    # Get the reader objects
    os.makedirs(args.out_dir, exist_ok=True)
    readers = construct_readers(args.readers, base_dir=args.out_dir,
                                n_proc=args.num_cores)

    # Record the reader versions used in this run.
    reader_versions = {reader.name: reader.get_version()
                       for reader in readers}
    s3.put_object(Bucket=bucket_name,
                  Key=get_s3_reader_version_loc(args.s3_base, args.job_name),
                  Body=json.dumps(reader_versions))

    # Get a handle for the database
    if args.test:
        from indra_db.tests.util import get_temp_db
        db = get_temp_db(clear=True)
    else:
        db = None

    # Read everything ========================================
    if args.batch is None:
        run_reading(readers, tcids, verbose=True, db=db,
                    reading_mode=args.read_mode, rslt_mode=args.rslt_mode)
    else:
        for tcid_batch in batch_iter(tcids, args.batch):
            run_reading(readers, tcid_batch, verbose=True, db=db,
                        reading_mode=args.read_mode,
                        rslt_mode=args.rslt_mode)

    # Preserve the sparser and TRIPS logs
    contents = os.listdir('.')
    logger.info("Checking for any log files to cache:\n"
                + '\n'.join(contents))
    sparser_logs = []
    trips_logs = []
    for fname in contents:
        # Check if this file is a sparser log
        if fname.startswith('sparser') and fname.endswith('log'):
            sparser_logs.append(fname)
        elif is_trips_datestring(fname):
            for sub_fname in os.listdir(fname):
                if sub_fname.endswith(('.log', '.err')):
                    trips_logs.append(os.path.join(fname, sub_fname))
    _dump_logs_to_s3(s3, s3_log_prefix, 'sparser', sparser_logs)
    _dump_logs_to_s3(s3, s3_log_prefix, 'trips', trips_logs)
def test_iterator_slicing():
    """Check chunked iteration over a correlation matrix with NaN holes.

    Knocks out random symmetric off-diagonal entries of a generated
    DataFrame and checks ``get_pairs`` against the expected count. It then
    checks that splitting ``corr_matrix_to_generator`` output into the
    requested number of chunks visits every counted pair. This is done for
    the whole matrix and for a random subset of names plus one name that is
    not in the frame.
    """
    size = 50
    max_retries = 1000
    a = _gen_sym_df(size)
    pairs = set()
    for _ in range(size):
        row, col = _get_off_diag_pair(size)
        retries = 0
        # Redraw until we get a pair that has not been removed yet.
        while (row, col) in pairs and retries <= max_retries:
            row, col = _get_off_diag_pair(size)
            retries += 1
        if (row, col) in pairs:
            print('Too many while iterations, breaking')
            break
        a.iloc[row, col] = np.nan
        a.iloc[col, row] = np.nan
        pairs.add((row, col))
        pairs.add((col, row))
    # Count what was actually removed instead of deriving it from the loop
    # index (which over-counted by one when the retry cap was hit). Each
    # removed pair is stored in both orientations.
    pairs_removed = len(pairs) // 2

    # Assert that we're correct so far
    assert (size**2 - size - 2 * pairs_removed) / 2 == get_pairs(a)

    def _check_chunking(expected_pairs, subset_list=None):
        # Iterate the pair generator in chunks and check that every pair is
        # visited and that the requested number of chunks is produced.
        chunks_wanted = 10
        chunksize = get_chunk_size(chunks_wanted, expected_pairs)
        if subset_list is None:
            pair_gen = corr_matrix_to_generator(a)
        else:
            pair_gen = corr_matrix_to_generator(a, subset_list=subset_list)
        chunk_iter = batch_iter(iterator=pair_gen, batch_size=chunksize,
                                return_func=list)
        pair_count = 0
        chunk_ix = 0
        for chunk_ix, list_of_pairs in enumerate(chunk_iter):
            pair_count += sum(t is not None for t in list_of_pairs)
        # Were all pairs looped?
        assert pair_count == expected_pairs, \
            f'pair_count={pair_count} total_pairs={expected_pairs}'
        # Does the number of loop iterations correspond to the number of
        # chunks wanted?
        assert chunk_ix + 1 == chunks_wanted, \
            f'chunk_ix+1={chunk_ix + 1}, chunks_wanted={chunks_wanted}'

    # Check that the iterator slicing for multiprocessing runs through all
    # the pairs
    total_pairs = get_pairs(a)
    _check_chunking(total_pairs)

    # Redo the same with subset of names
    name_subset = list(
        np.random.choice(a.columns.values, size=size // 3, replace=False))
    # Add a name that does not exist in the original df
    name_subset.append(size + 2)
    total_pairs_permute = get_pairs(a, subset_list=name_subset)
    _check_chunking(total_pairs_permute, subset_list=name_subset)
def reground_texts(texts, ont_yml, webservice, topk=10,
                   is_canonicalized=False, filter=True, cache_path=None,
                   batch_size=500):
    """Ground concept texts given an ontology with an Eidos web service.

    Parameters
    ----------
    texts : list[str]
        A list of concept texts to ground.
    ont_yml : str
        A serialized YAML string representing the ontology.
    webservice : str
        The address where the Eidos web service is running, e.g.,
        http://localhost:9000.
    topk : Optional[int]
        The number of top scoring groundings to return. Default: 10
    is_canonicalized : Optional[bool]
        If True, the texts are assumed to be canonicalized. If False,
        Eidos will canonicalize the texts which yields much better groundings
        but is slower. Default: False
    filter : Optional[bool]
        If True, Eidos filters the ontology to remove determiners from
        examples and other similar operations. Should typically be set to
        True. Default: True
    cache_path : Optional[str]
        Path to a pickle file used as a grounding cache. If the file exists,
        its groundings are loaded and only texts missing from it are sent
        to the web service. The updated cache is written back to this path
        at the end. Default: None (no caching)
    batch_size : Optional[int]
        Number of texts sent to the web service per request. Default: 500

    Returns
    -------
    list
        One grounding per entry of `texts`, in the same order, as produced
        by `grounding_dict_to_list` from the web service responses.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer')

    grounding_cache = {}
    if cache_path and os.path.exists(cache_path):
        # NOTE: pickle.load can execute arbitrary code; only point
        # cache_path at files produced by this function.
        with open(cache_path, 'rb') as fh:
            grounding_cache = pickle.load(fh)
        logger.info('Loaded %d groundings from cache' % len(grounding_cache))

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # requests sent to the service are deterministic across runs.
    texts_to_ground = [txt for txt in dict.fromkeys(texts)
                       if txt not in grounding_cache]
    logger.info('Grounding a total of %d texts' % len(texts_to_ground))
    batches = batch_iter(texts_to_ground, batch_size=batch_size,
                         return_func=list)
    for text_batch in tqdm.tqdm(
            batches, total=math.ceil(len(texts_to_ground) / batch_size)):
        params = {
            'ontologyYaml': ont_yml,
            'texts': text_batch,
            'topk': topk,
            'isAlreadyCanonicalized': is_canonicalized,
            'filter': filter
        }
        res = requests.post('%s/reground' % webservice, json=params)
        res.raise_for_status()
        grounding_for_texts = grounding_dict_to_list(res.json())
        for txt, grounding in zip(text_batch, grounding_for_texts):
            grounding_cache[txt] = grounding

    all_results = [grounding_cache[txt] for txt in texts]
    if cache_path:
        with open(cache_path, 'wb') as fh:
            pickle.dump(grounding_cache, fh)
    return all_results
def get_statements(clauses, count=1000, do_stmt_count=False, db=None,
                   preassembled=True, with_support=False, fix_refs=True,
                   with_evidence=True):
    """Select statements according to a given set of clauses.

    Parameters
    ----------
    clauses : list
        list of sqlalchemy WHERE clauses to pass to the filter query. The
        list passed in is not modified.
    count : int
        Number of statements to retrieve and process in each batch.
    do_stmt_count : bool
        Whether or not to perform an initial statement counting step to give
        more meaningful progress messages.
    db : :py:class:`DatabaseManager`
        Optionally specify a database manager that attaches to something
        besides the primary database, for example a local database instance.
    preassembled : bool
        If true, statements will be selected from the table of pre-assembled
        statements. Otherwise, they will be selected from the raw statements.
        Default is True.
    with_support : bool
        Choose whether to populate the supports and supported_by list
        attributes of the Statement objects. General results in slower
        queries.
    with_evidence : bool
        Choose whether or not to populate the evidence list attribute of the
        Statements. As with `with_support`, setting this to True will take
        longer.
    fix_refs : bool
        The paper refs within the evidence objects are not populated in the
        database, and thus must be filled using the relations in the
        database. If True (default), the `pmid` field of each Statement
        Evidence object is set to the correct PMIDs, or None if no PMID is
        available. If False, the `pmid` field defaults to the value
        populated by the reading system.

    Returns
    -------
    list of Statements from the database corresponding to the query.
    """
    # stacklevel=2 attributes the warning to the caller, not this module.
    warnings.warn(('This module is being taken out of service, as the tools '
                   'have become deprecated. Moreover, the service has been '
                   're-implemented to use newer tools as best as possible, '
                   'but some results may be subtly different.'),
                  DeprecationWarning, stacklevel=2)
    if db is None:
        db = get_primary_db()

    stmts_tblname = 'pa_statements' if preassembled else 'raw_statements'

    if not preassembled:
        stmts = []
        q = db.filter_query(stmts_tblname, *clauses)
        if do_stmt_count:
            logger.info("Counting statements...")
            num_stmts = q.count()
            logger.info("Total of %d statements" % num_stmts)
        db_stmts = q.yield_per(count)
        for subset in batch_iter(db_stmts, count):
            stmts.extend(get_raw_stmts_frm_db_list(db, subset,
                                                   with_sids=False,
                                                   fix_refs=fix_refs))
            if do_stmt_count:
                logger.info("%d of %d statements" % (len(stmts), num_stmts))
            else:
                logger.info("%d statements" % len(stmts))
        return stmts

    logger.info("Getting preassembled statements.")
    if with_evidence:
        # Get pairs of pa statements with their linked raw statements. Build
        # a new list so the caller's `clauses` is never extended in place.
        join_clauses = [
            *clauses,
            db.PAStatements.mk_hash == db.RawUniqueLinks.pa_stmt_mk_hash,
            db.RawStatements.id == db.RawUniqueLinks.raw_stmt_id,
        ]
        pa_raw_stmt_pairs = \
            db.select_all([db.PAStatements, db.RawStatements],
                          *join_clauses, yield_per=count)
        stmt_dict = _process_pa_statement_res_wev(db, pa_raw_stmt_pairs,
                                                  count=count,
                                                  fix_refs=fix_refs)
    else:
        # Get just pa statements without their supporting raw statement(s).
        pa_stmts = db.select_all(db.PAStatements, *clauses, yield_per=count)
        stmt_dict = _process_pa_statement_res_nev(pa_stmts, count=count)

    # Populate the supports/supported by fields.
    if with_support:
        get_support(stmt_dict, db=db)

    stmts = list(stmt_dict.values())
    logger.info("In all, there are %d pa statements." % len(stmts))
    return stmts
def get_statements(clauses, count=1000, do_stmt_count=False, db=None,
                   preassembled=True, with_support=False, fix_refs=True,
                   with_evidence=True):
    """Select statements according to a given set of clauses.

    Parameters
    ----------
    clauses : list
        list of sqlalchemy WHERE clauses to pass to the filter query. The
        list passed in is not modified.
    count : int
        Number of statements to retrieve and process in each batch.
    do_stmt_count : bool
        Whether or not to perform an initial statement counting step to give
        more meaningful progress messages.
    db : :py:class:`DatabaseManager`
        Optionally specify a database manager that attaches to something
        besides the primary database, for example a local database instance.
    preassembled : bool
        If true, statements will be selected from the table of pre-assembled
        statements. Otherwise, they will be selected from the raw statements.
        Default is True.
    with_support : bool
        Choose whether to populate the supports and supported_by list
        attributes of the Statement objects. General results in slower
        queries. Linked statements that were not selected are represented
        by `Unresolved` placeholders.
    with_evidence : bool
        Choose whether or not to populate the evidence list attribute of the
        Statements. As with `with_support`, setting this to True will take
        longer.
    fix_refs : bool
        The paper refs within the evidence objects are not populated in the
        database, and thus must be filled using the relations in the
        database. If True (default), the `pmid` field of each Statement
        Evidence object is set to the correct PMIDs, or None if no PMID is
        available. If False, the `pmid` field defaults to the value
        populated by the reading system.

    Returns
    -------
    list of Statements from the database corresponding to the query.

    Raises
    ------
    AssertionError
        If a support link from a statement to itself is found.
    """
    if db is None:
        db = get_primary_db()

    stmts_tblname = 'pa_statements' if preassembled else 'raw_statements'

    if not preassembled:
        stmts = []
        q = db.filter_query(stmts_tblname, *clauses)
        if do_stmt_count:
            logger.info("Counting statements...")
            num_stmts = q.count()
            logger.info("Total of %d statements" % num_stmts)
        db_stmts = q.yield_per(count)
        for subset in batch_iter(db_stmts, count):
            stmts.extend(get_raw_stmts_frm_db_list(db, subset,
                                                   with_sids=False,
                                                   fix_refs=fix_refs))
            if do_stmt_count:
                logger.info("%d of %d statements" % (len(stmts), num_stmts))
            else:
                logger.info("%d statements" % len(stmts))
        return stmts

    logger.info("Getting preassembled statements.")
    stmt_dict = {}
    if with_evidence:
        # Get pairs of pa statements with their linked raw statements. Build
        # a new list so the caller's `clauses` is never extended in place.
        join_clauses = [*clauses, *db.join(db.PAStatements, db.RawStatements)]
        pa_raw_stmt_pairs = \
            db.select_all([db.PAStatements, db.RawStatements],
                          *join_clauses, yield_per=count)

        # Iterate over the batches to create the statement objects.
        ev_dict = {}
        raw_stmt_dict = {}
        total_ev = 0
        for stmt_pair_batch in batch_iter(pa_raw_stmt_pairs, count):
            # Instantiate the PA statement objects, and record the uuid
            # evidence (raw statement) links.
            raw_stmt_objs = []
            for pa_stmt_db_obj, raw_stmt_db_obj in stmt_pair_batch:
                k = pa_stmt_db_obj.mk_hash
                if k not in stmt_dict:
                    stmt_dict[k] = _get_statement_object(pa_stmt_db_obj)
                    ev_dict[k] = []
                ev_dict[k].append(raw_stmt_db_obj.id)
                raw_stmt_objs.append(raw_stmt_db_obj)
                total_ev += 1

            logger.info("Up to %d pa statements, with %d pieces of "
                        "evidence in all." % (len(stmt_dict), total_ev))

            # Instantiate the raw statements. Pass fix_refs by keyword, as
            # the raw-statement branch above does.
            raw_stmt_sid_tpls = get_raw_stmts_frm_db_list(
                db, raw_stmt_objs, fix_refs=fix_refs, with_sids=True)
            raw_stmt_dict.update({sid: s for sid, s in raw_stmt_sid_tpls})
            logger.info("Processed %d raw statements."
                        % len(raw_stmt_sid_tpls))

        # Attach the evidence
        logger.info("Inserting evidence.")
        for k, sid_list in ev_dict.items():
            stmt_dict[k].evidence = [raw_stmt_dict[sid].evidence[0]
                                     for sid in sid_list]
    else:
        # Get just pa statements without their supporting raw statement(s).
        pa_stmts = db.select_all(db.PAStatements, *clauses, yield_per=count)

        # Iterate over the batches to create the statement objects.
        for pa_stmt_batch in batch_iter(pa_stmts, count):
            # Instantiate the PA statement objects.
            for pa_stmt_db_obj in pa_stmt_batch:
                k = pa_stmt_db_obj.mk_hash
                if k not in stmt_dict:
                    stmt_dict[k] = _get_statement_object(pa_stmt_db_obj)
            logger.info("Up to %d pa statements in all." % len(stmt_dict))

    # Populate the supports/supported by fields.
    if with_support:
        logger.info("Populating support links.")
        hashes = list(stmt_dict.keys())
        support_links = db.select_all(
            [db.PASupportLinks.supported_mk_hash,
             db.PASupportLinks.supporting_mk_hash],
            or_(db.PASupportLinks.supported_mk_hash.in_(hashes),
                db.PASupportLinks.supporting_mk_hash.in_(hashes))
        )
        for supped_hash, supping_hash in set(support_links):
            if supped_hash == supping_hash:
                # Raised explicitly so the check survives `python -O`.
                raise AssertionError('Self-support found on-load.')
            supped_stmt = stmt_dict.get(supped_hash)
            if supped_stmt is None:
                supped_stmt = Unresolved(shallow_hash=supped_hash)
            supping_stmt = stmt_dict.get(supping_hash)
            if supping_stmt is None:
                supping_stmt = Unresolved(shallow_hash=supping_hash)
            supped_stmt.supported_by.append(supping_stmt)
            supping_stmt.supports.append(supped_stmt)

    stmts = list(stmt_dict.values())
    logger.info("In all, there are %d pa statements." % len(stmts))
    return stmts
def test_iterator_slicing():
    """Verify that chunked iteration over the correlation matrix is complete.

    Builds a frame with NaN holes via ``_get_df_w_nan``, checks the pair
    count reported by ``get_pairs``, and then confirms that slicing the
    ``corr_matrix_to_generator`` stream into batches covers every pair in
    exactly the requested number of chunks. This is checked for the whole
    frame and again for a random subset of names that also contains one
    name absent from the frame.
    """
    size = 50
    a, pairs_removed = _get_df_w_nan(size)

    def _verify_chunks(expected_total, subset_list=None):
        # Slice the pair stream into batches, count the non-None entries
        # (unpacking each one also checks its nested shape), and compare
        # with the expected totals.
        n_chunks = 10
        batch_len = get_chunk_size(n_chunks, expected_total)
        if subset_list is None:
            stream = corr_matrix_to_generator(a)
        else:
            stream = corr_matrix_to_generator(a, subset_list=subset_list)
        batches = batch_iter(iterator=stream, batch_size=batch_len,
                             return_func=list)
        seen = 0
        last_ix = 0
        for last_ix, batch in enumerate(batches):
            unpacked = [(t[0][0], t[0][1], t[1])
                        for t in batch if t is not None]
            seen += len(unpacked)
        assert seen == expected_total, \
            f'pair_count={seen} total_pairs={expected_total}'
        assert last_ix + 1 == n_chunks, \
            f'chunk_ix+1={last_ix + 1}, chunks_wanted={n_chunks}'

    # Full frame: all items - diagonal - all removed items off diagonal.
    full_total = get_pairs(a)
    assert (size**2 - size - 2*pairs_removed) / 2 == full_total
    _verify_chunks(full_total)

    # Subset of names, plus one name the frame does not contain.
    picked = list(np.random.choice(a.columns.values, size=size // 3,
                                   replace=False))
    picked.append(size+2)
    subset_total = get_pairs(a, subset_list=picked)
    _verify_chunks(subset_total, subset_list=picked)