Example #1
    def _pa_batch_iter(self, db, in_mks=None, ex_mks=None):
        """Return an iterator over batches of preassembled statements.

        This avoids the need to load all such statements from the database
        into RAM at the same time (the full set can be quite large).

        You may limit the set of pa_statements loaded by providing a set or
        list of matches-key hashes to include (in_mks) or exclude (ex_mks).
        """
        if in_mks is None and ex_mks is None:
            db_stmt_iter = db.select_all(db.PAStatements.json,
                                         yield_per=self.batch_size)
        elif ex_mks is None and in_mks:
            db_stmt_iter = db.select_all(db.PAStatements.json,
                                         db.PAStatements.mk_hash.in_(in_mks),
                                         yield_per=self.batch_size)
        elif in_mks is None and ex_mks:
            db_stmt_iter = db.select_all(
                db.PAStatements.json,
                db.PAStatements.mk_hash.notin_(ex_mks),
                yield_per=self.batch_size)
        elif in_mks and ex_mks:
            db_stmt_iter = db.select_all(
                db.PAStatements.json,
                db.PAStatements.mk_hash.notin_(ex_mks),
                db.PAStatements.mk_hash.in_(in_mks),
                yield_per=self.batch_size)
        else:
            db_stmt_iter = db.select_all(db.PAStatements.json,
                                         yield_per=self.batch_size)

        pa_stmts = (_stmt_from_json(s_json) for s_json, in db_stmt_iter)
        return batch_iter(pa_stmts, self.batch_size, return_func=list)
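All of the examples on this page rely on a batch_iter helper whose implementation is not shown. The following is a minimal sketch of what such a helper could look like, assuming the signature batch_iter(iterator, batch_size, return_func=None) implied by the call sites; the real indra.util.batch_iter may differ in detail.

from itertools import chain, islice

def batch_iter(iterator, batch_size, return_func=None):
    """Yield successive batches of at most batch_size items from iterator.

    If return_func (e.g. list or set) is given, each batch is materialized
    with it; otherwise each batch is itself a lazy iterator.
    """
    it = iter(iterator)
    sentinel = object()
    while True:
        first = next(it, sentinel)
        if first is sentinel:
            return
        batch = chain([first], islice(it, batch_size - 1))
        yield return_func(batch) if return_func is not None else batch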
Example #2
    def _raw_sid_stmt_iter(self, db, id_set, do_enumerate=False):
        """Return a generator over statements with the given database ids."""
        def _fixed_raw_stmt_from_json(s_json, tr):
            stmt = _stmt_from_json(s_json)
            if tr is not None:
                stmt.evidence[0].pmid = tr.pmid
                stmt.evidence[0].text_refs = {
                    k: v
                    for k, v in tr.__dict__.items() if not k.startswith('_')
                }
            return stmt

        i = 0
        for stmt_id_batch in batch_iter(id_set, self.batch_size):
            subres = (db.filter_query(
                [db.RawStatements.id, db.RawStatements.json, db.TextRef],
                db.RawStatements.id.in_(stmt_id_batch))
                .outerjoin(db.Reading)
                .outerjoin(db.TextContent)
                .outerjoin(db.TextRef)
                .yield_per(self.batch_size // 10))
            data = [(sid, _fixed_raw_stmt_from_json(s_json, tr))
                    for sid, s_json, tr in subres]
            if do_enumerate:
                yield i, data
                i += 1
            else:
                yield data
Example #3
    def _get_unique_statements(self, db, stmt_tpls, num_stmts, mk_done=None):
        """Get the unique Statements from the raw statements."""
        if mk_done is None:
            mk_done = set()

        new_mk_set = set()
        stmt_batches = batch_iter(stmt_tpls, self.batch_size, return_func=list)
        num_batches = num_stmts / self.batch_size
        for i, stmt_tpl_batch in enumerate(stmt_batches):
            self._log("Processing batch %d/%d of %d/%d statements." %
                      (i, num_batches, len(stmt_tpl_batch), num_stmts))
            unique_stmts, evidence_links = \
                self._make_unique_statement_set(stmt_tpl_batch)
            new_unique_stmts = []
            for s in unique_stmts:
                s_hash = s.get_hash(shallow=True)
                if s_hash not in (mk_done | new_mk_set):
                    new_mk_set.add(s_hash)
                    new_unique_stmts.append(s)
            insert_pa_stmts(db, new_unique_stmts)
            db.copy('raw_unique_links', evidence_links,
                    ('pa_stmt_mk_hash', 'raw_stmt_id'))
        self._log("Added %d new pa statements into the database." %
                  len(new_mk_set))
        return new_mk_set
Example #4
def get_unique_text_refs():
    """Get unique INDRA DB TextRef IDs for all identifiers in CORD19.

    Queries TextRef IDs with PMIDs, PMCIDs, and DOIs from CORD19, then
    deduplicates to obtain a unique set of TextRefs.

    Returns
    -------
    list of TextRef
        The unique TextRef objects from the INDRA DB.
    """
    pmcids = get_ids('pmcid')
    pmids = [fix_pmid(pmid) for pmid in get_ids('pubmed_id')]
    dois = [fix_doi(doi) for doi in get_ids('doi')]
    # Get unique text_refs from the DB
    db = get_primary_db()
    print("Getting TextRefs by PMCID")
    tr_pmcids = db.select_all(db.TextRef.id, db.TextRef.pmcid_in(pmcids))
    print("Getting TextRefs by PMID")
    tr_pmids = db.select_all(db.TextRef.id, db.TextRef.pmid_in(pmids))
    tr_dois = []
    for ix, doi_batch in enumerate(batch_iter(dois, 10000)):
        print("Getting Text Refs by DOI batch", ix)
        tr_doi_batch = db.select_all(
            db.TextRef.id, db.TextRef.doi_in(doi_batch, filter_ids=True))
        tr_dois.extend(tr_doi_batch)
    ids = set([
        res.id for res_list in (tr_dois, tr_pmcids, tr_pmids)
        for res in res_list
    ])
    print(len(ids), "unique TextRefs in DB")
    trs = db.select_all(db.TextRef, db.TextRef.id.in_(ids))
    return trs
Example #5
def get_indradb_pa_stmts():
    """Get preassembled INDRA Stmts for PMC articles from INDRA DB.

    DEPRECATED. Get Raw Statements instead.
    """
    # Get the list of all PMCIDs from the corpus metadata
    pmcids = get_ids('pmcid')
    paper_refs = [('pmcid', p) for p in pmcids]
    stmt_jsons = []
    batch_size = 1000
    start = time.time()
    for batch_ix, paper_batch in enumerate(batch_iter(paper_refs, batch_size)):
        if batch_ix <= 5:
            continue
        papers = list(paper_batch)
        print("Querying DB for statements for %d papers" % batch_size)
        batch_start = time.time()
        result = get_statement_jsons_from_papers(papers)
        batch_elapsed = time.time() - batch_start
        batch_jsons = [
            stmt_json for stmt_hash, stmt_json in result['statements'].items()
        ]
        print("Returned %d stmts in %f sec" %
              (len(batch_jsons), batch_elapsed))
        batch_stmts = stmts_from_json(batch_jsons)
        ac.dump_statements(batch_stmts, 'batch_%02d.pkl' % batch_ix)
        stmt_jsons += batch_jsons
    elapsed = time.time() - start
    print("Total time: %f sec, %d papers" % (elapsed, len(paper_refs)))
    stmts = stmts_from_json(stmt_jsons)
    ac.dump_statements(stmts, 'cord19_pmc_stmts.pkl')
    return stmt_jsons
Example #6
    def populate(self, db):
        # Turn the list of dicts into a set of tuples
        tr_data_set = {
            tuple([entry[id_type] for id_type in self.tr_cols])
            for entry in self.tr_data
        }
        # Filter_text_refs will figure out which articles are already in the
        # TextRef table and will update them with any new metadata;
        # filtered_tr_records are the ones that need to be added to the DB
        filtered_tr_records = set()
        flawed_tr_records = set()
        for ix, tr_batch in enumerate(batch_iter(tr_data_set, 10000)):
            print("Getting Text Refs using pmid/pmcid/doi, batch", ix)
            filt_batch, flaw_batch = \
                    self.filter_text_refs(db, set(tr_batch),
                                    primary_id_types=['pmid', 'pmcid', 'doi'])
            filtered_tr_records |= set(filt_batch)
            flawed_tr_records |= set(flaw_batch)
        trs_to_skip = {rec for cause, rec in flawed_tr_records}
        # Why did the original version not skip in case of disagreeing
        # pmid or doi?
        #pmcids_to_skip = {rec[self.tr_cols.index('pmcid')]
        #                  for cause, rec in flawed_tr_records
        #                  if cause in ['pmcid', 'over_match_input',
        #                               'over_match_db']}

        # Then we put together the updated text content data
        if len(trs_to_skip) != 0:
            mod_tc_data = [
                tc for tc in self.tc_data if (tc.get('pmid'), tc.get('pmcid'),
                                              tc.get('doi')) not in trs_to_skip
            ]
        else:
            mod_tc_data = self.tc_data

        # Upload TextRef data for articles NOT already in the DB
        logger.info('Adding %d new text refs...' % len(filtered_tr_records))
        if filtered_tr_records:
            self.copy_into_db(db, 'text_ref', filtered_tr_records,
                              self.tr_cols)
        gatherer.add('refs', len(filtered_tr_records))

        # Process the text content data
        filtered_tc_records, flawed_tcs = \
                            self.filter_text_content(db, mod_tc_data)

        # Upload the text content data.
        logger.info('Adding %d more text content entries...' %
                    len(filtered_tc_records))
        self.copy_into_db(db, 'text_content', filtered_tc_records,
                          self.tc_cols)
        gatherer.add('content', len(filtered_tc_records))
        return {
            'filtered_tr_records': filtered_tr_records,
            'flawed_tr_records': flawed_tr_records,
            'mod_tc_data': mod_tc_data,
            'filtered_tc_records': filtered_tc_records
        }
Example #7
def download_statements(hashes):
    """Download the INDRA Statements corresponding to a set of hashes.
    """
    stmts_by_hash = {}
    for group in tqdm.tqdm(batch_iter(hashes, 200),
                           total=int(len(hashes) / 200)):
        idbp = indra_db_rest.get_statements_by_hash(list(group), ev_limit=10)
        for stmt in idbp.statements:
            stmts_by_hash[stmt.get_hash()] = stmt
    return stmts_by_hash
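A hypothetical usage sketch for the function above; the hashes are placeholders, not real statement hashes.

example_hashes = [-30465309842508003, 13159234982749847]  # placeholder values
stmts_by_hash = download_statements(example_hashes)
print('Downloaded %d statements' % len(stmts_by_hash))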
Example #8
def download_statements(df, ev_limit=5):
    """Download the INDRA Statements corresponding to entries in a data frame.
    """
    all_stmts = []
    for idx, group in enumerate(batch_iter(df.hash, 500)):
        logger.info('Getting statement batch %d' % idx)
        idbp = indra_db_rest.get_statements_by_hash(list(group),
                                                    ev_limit=ev_limit)
        all_stmts += idbp.statements
    return all_stmts
Example #9
def get_raw_statements_for_pmids(pmids, mode='all', batch_size=100):
    """Return EmmaaStatements based on extractions from given PMIDs.

    Parameters
    ----------
    pmids : set or list of str
        A set of PMIDs to find raw INDRA Statements for in the INDRA DB.
    mode : 'all' or 'distilled'
        The 'distilled' mode makes sure that the "best", non-redundant
        set of raw statements are found across potentially redundant text
        contents and reader versions. The 'all' mode doesn't do such
        distillation but is significantly faster.
    batch_size : Optional[int]
        Determines how many PMIDs to fetch statements for in each
        iteration. Default: 100.

    Returns
    -------
    dict
        A dict keyed by PMID whose values are lists of INDRA Statements
        obtained from that PMID.
    """
    db = get_db('primary')
    logger.info(f'Getting raw statements for {len(pmids)} PMIDs')
    all_stmts = defaultdict(list)
    for pmid_batch in tqdm.tqdm(batch_iter(pmids,
                                           return_func=set,
                                           batch_size=batch_size),
                                total=len(pmids) / batch_size):
        if mode == 'distilled':
            clauses = [
                db.TextRef.pmid.in_(pmid_batch),
                db.TextContent.text_ref_id == db.TextRef.id,
                db.Reading.text_content_id == db.TextContent.id,
                db.RawStatements.reading_id == db.Reading.id
            ]
            distilled_stmts = distill_stmts(db,
                                            get_full_stmts=True,
                                            clauses=clauses)
            for stmt in distilled_stmts:
                all_stmts[stmt.evidence[0].pmid].append(stmt)
        else:
            id_stmts = \
                get_raw_stmt_jsons_from_papers(pmid_batch, id_type='pmid',
                                               db=db)
            for pmid, stmt_jsons in id_stmts.items():
                all_stmts[pmid] += stmts_from_json(stmt_jsons)
    all_stmts = dict(all_stmts)
    return all_stmts
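A hypothetical usage sketch for get_raw_statements_for_pmids; the PMIDs below are placeholders.

example_pmids = {'32015508', '32142651'}  # placeholder PMIDs
stmts_by_pmid = get_raw_statements_for_pmids(example_pmids, mode='all',
                                             batch_size=50)
for pmid, stmts in stmts_by_pmid.items():
    print('%s: %d raw statements' % (pmid, len(stmts)))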
Example #10
    def _raw_sid_stmt_iter(self, db, id_set, do_enumerate=False):
        """Return a generator over statements with the given database ids."""
        i = 0
        for stmt_id_batch in batch_iter(id_set, self.batch_size):
            subres = db.select_all(
                [db.RawStatements.id, db.RawStatements.json],
                db.RawStatements.id.in_(stmt_id_batch),
                yield_per=self.batch_size // 10)
            if do_enumerate:
                yield i, [(sid, _stmt_from_json(s_json))
                          for sid, s_json in subres]
                i += 1
            else:
                yield [(sid, _stmt_from_json(s_json))
                       for sid, s_json in subres]
Example #11
def _process_pa_statement_res_wev(db,
                                  stmt_iterable,
                                  count=1000,
                                  fix_refs=True):
    warnings.warn(('This module is being taken out of service, as the tools '
                   'have become deprecated. Moreover, the service has been '
                   're-implemented to use newer tools as best as possible, '
                   'but some results may be subtly different.'),
                  DeprecationWarning)
    # Iterate over the batches to create the statement objects.
    stmt_dict = {}
    ev_dict = {}
    raw_stmt_dict = {}
    total_ev = 0
    for stmt_pair_batch in batch_iter(stmt_iterable, count):
        # Instantiate the PA statement objects, and record the uuid
        # evidence (raw statement) links.
        raw_stmt_objs = []
        for pa_stmt_db_obj, raw_stmt_db_obj in stmt_pair_batch:
            k = pa_stmt_db_obj.mk_hash
            if k not in stmt_dict.keys():
                stmt_dict[k] = get_statement_object(pa_stmt_db_obj)
                ev_dict[k] = [
                    raw_stmt_db_obj.id,
                ]
            else:
                ev_dict[k].append(raw_stmt_db_obj.id)
            raw_stmt_objs.append(raw_stmt_db_obj)
            total_ev += 1

        logger.info("Up to %d pa statements, with %d pieces of "
                    "evidence in all." % (len(stmt_dict), total_ev))

        # Instantiate the raw statements.
        raw_stmt_sid_tpls = get_raw_stmts_frm_db_list(db,
                                                      raw_stmt_objs,
                                                      fix_refs,
                                                      with_sids=True)
        raw_stmt_dict.update({sid: s for sid, s in raw_stmt_sid_tpls})
        logger.info("Processed %d raw statements." % len(raw_stmt_sid_tpls))

    # Attach the evidence
    logger.info("Inserting evidence.")
    for k, sid_list in ev_dict.items():
        stmt_dict[k].evidence = [
            raw_stmt_dict[sid].evidence[0] for sid in sid_list
        ]
    return stmt_dict
Example #12
    def get_mutated_genes(self):
        """Return dict of gene mutation frequencies based on TCGA studies."""
        if self.mutation_cache:
            logger.info('Loading mutations from %s' % self.mutation_cache)
            with open(self.mutation_cache, 'r') as fh:
                self.mutations = json.load(fh)
        else:
            logger.info('Getting mutations from cBio web service')
            mutations = {}
            for tcga_study_name in tcga_studies[self.tcga_study_prefix]:
                for idx, hgnc_name_batch in \
                                enumerate(batch_iter(hgnc_ids.keys(), 200)):
                    logger.info('Fetching mutations for %s and gene batch %s' %
                                (tcga_study_name, idx))
                    patient_mutations = \
                        cbio_client.get_profile_data(tcga_study_name,
                                                     hgnc_name_batch,
                                                     'mutation')
                    # e.g. 'ICGC_0002_TD': {'BRAF': None, 'KRAS': 'G12D'}
                    for patient, gene_mut_dict in patient_mutations.items():
                        # 'BRAF': None
                        for gene, mutated in gene_mut_dict.items():
                            if mutated is not None:
                                try:
                                    mutations[gene] += 1
                                except KeyError:
                                    mutations[gene] = 1
            self.mutations = mutations

        # Normalize mutations by length
        self.norm_mutations = {}
        for gene_name, num_muts in self.mutations.items():
            hgnc_id = get_hgnc_id(gene_name)
            up_id = get_uniprot_id(hgnc_id)
            if not up_id:
                logger.warning("Could not get Uniprot ID for HGNC symbol %s "
                               "with HGNC ID %s" % (gene_name, hgnc_id))
                length = 500  # a guess at a default
            else:
                length = uniprot_client.get_length(up_id)
                if not length:
                    logger.warning("Could not get length for Uniprot "
                                   "ID %s" % up_id)
                    length = 500  # a guess at a default
            self.norm_mutations[gene_name] = num_muts / float(length)

        return self.mutations, self.norm_mutations
Example #13
def _process_pa_statement_res_nev(stmt_iterable, count=1000):
    warnings.warn(('This module is being taken out of service, as the tools '
                   'have become deprecated. Moreover, the service has been '
                   're-implemented to use newer tools as best as possible, '
                   'but some results may be subtly different.'),
                  DeprecationWarning)
    # Iterate over the batches to create the statement objects.
    stmt_dict = {}
    for stmt_pair_batch in batch_iter(stmt_iterable, count):
        # Instantiate the PA statement objects.
        for pa_stmt_db_obj in stmt_pair_batch:
            k = pa_stmt_db_obj.mk_hash
            if k not in stmt_dict.keys():
                stmt_dict[k] = get_statement_object(pa_stmt_db_obj)

        logger.info("Up to %d pa statements in all." % len(stmt_dict))
    return stmt_dict
Example #14
def download_statements(df, beliefs, ev_limit=5):
    """Download the INDRA Statements corresponding to entries in a data frame.
    """
    all_stmts = []
    unique_hashes = list(set(df.stmt_hash))
    batches = list(batch_iter(unique_hashes, 500))
    logger.info('Getting %d unique hashes from db' % len(unique_hashes))
    for group in tqdm.tqdm(batches):
        idbp = indra_db_rest.get_statements_by_hash(list(group),
                                                    ev_limit=ev_limit)
        all_stmts += idbp.statements
    for stmt in all_stmts:
        belief = beliefs.get(stmt.get_hash())
        if belief is None:
            logger.info('No belief found for %s' % str(stmt))
            continue
        stmt.belief = belief
    return all_stmts
Example #15
    def select_all_batched(self, batch_size, tbls, *args, skip_idx=None,
                           order_by=None):
        """Load the results of a query in batches of size batch_size.

        Note that this differs from using yield_per in that the results are
        not returned as a single iterable, but as an iterator of iterables.

        Note also that the order of results, and thus the contents of offsets,
        may vary for large queries unless an explicit order_by clause is added
        to the query.
        """
        q = self.filter_query(tbls, *args)
        if order_by:
            q = q.order_by(order_by)
        res_iter = q.yield_per(batch_size)
        for i, batch in enumerate(batch_iter(res_iter, batch_size)):
            if i != skip_idx:
                yield i, batch
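A hypothetical usage sketch, assuming a db manager with the PAStatements table used in the other examples; ordering by mk_hash keeps batch contents stable across runs.

for batch_ix, batch in db.select_all_batched(
        10000, [db.PAStatements.json],
        order_by=db.PAStatements.mk_hash):
    for s_json, in batch:
        pass  # placeholder for per-row processing of the statement JSON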
Example #16
def get_stmts_pmids_mesh(subject, stmt_type, object_list):
    stmts = []
    for obj in object_list:
        idrp = idr.get_statements(subject=subject,
                                  object=obj,
                                  stmt_type=stmt_type,
                                  ev_limit=10000)
        stmts += idrp.statements

    # Collect the PMIDs for the stmts
    pmids = [e.pmid for s in stmts for e in s.evidence]

    mesh_terms = []
    for batch in batch_iter(pmids, 200):
        pmid_list = list(batch)
        print("Retrieving metadata for %d articles" % len(pmid_list))
        metadata = get_metadata_for_ids(pmid_list)
        for pmid, pmid_meta in metadata.items():
            mesh_terms += [d['mesh'] for d in pmid_meta['mesh_annotations']]
    return (stmts, pmids, mesh_terms)
Example #17
    def get_mutated_genes(self):
        """Return dict of gene mutation frequencies based on TCGA studies."""
        if self.mutation_cache:
            logger.info('Loading mutations from %s' % self.mutation_cache)
            with open(self.mutation_cache, 'r') as fh:
                self.mutations = json.load(fh)
        else:
            logger.info('Getting mutations from cBio web service')
            mutations = {}
            for tcga_study_name in tcga_studies[self.tcga_study_prefix]:
                for idx, hgnc_name_batch in \
                                enumerate(batch_iter(hgnc_ids.keys(), 200)):
                    logger.info('Fetching mutations for %s and gene batch %s' %
                                (tcga_study_name, idx))
                    patient_mutations = \
                        cbio_client.get_profile_data(tcga_study_name,
                                                     hgnc_name_batch,
                                                     'mutation')
                    # e.g. 'ICGC_0002_TD': {'BRAF': None, 'KRAS': 'G12D'}
                    for patient, gene_mut_dict in patient_mutations.items():
                        # 'BRAF': None
                        for gene, mutated in gene_mut_dict.items():
                            if mutated is not None:
                                try:
                                    mutations[gene] += 1
                                except KeyError:
                                    mutations[gene] = 1
            self.mutations = mutations

        # Normalize mutations by length
        self.norm_mutations = {}
        for gene_name, num_muts in self.mutations.items():
            self.norm_mutations[gene_name] = \
                self.normalize_mutation_count(gene_name, num_muts)

        return self.mutations, self.norm_mutations
Example #18
def get_statements(clauses, count=1000, do_stmt_count=True, db=None,
                   preassembled=True, with_support=False):
    """Select statements according to a given set of clauses.

    Parameters
    ----------
    clauses : list
        list of sqlalchemy WHERE clauses to pass to the filter query.
    count : int
        Number of statements to retrieve and process in each batch.
    do_stmt_count : bool
        Whether or not to perform an initial statement counting step to give
        more meaningful progress messages.
    db : :py:class:`DatabaseManager`
        Optionally specify a database manager that attaches to something
        besides the primary database, for example a local database instance.
    preassembled : bool
        If true, statements will be selected from the table of pre-assembled
        statements. Otherwise, they will be selected from the raw statements.
        Default is True.

    Returns
    -------
    list of Statements from the database corresponding to the query.
    """
    if db is None:
        db = get_primary_db()

    stmts_tblname = 'pa_statements' if preassembled else 'raw_statements'

    if not preassembled:
        stmts = []
        q = db.filter_query(stmts_tblname, *clauses)
        if do_stmt_count:
            logger.info("Counting statements...")
            num_stmts = q.count()
            logger.info("Total of %d statements" % num_stmts)
        db_stmts = q.yield_per(count)
        for subset in batch_iter(db_stmts, count):
            stmts.extend(make_raw_stmts_from_db_list(db, subset))
            if do_stmt_count:
                logger.info("%d of %d statements" % (len(stmts), num_stmts))
            else:
                logger.info("%d statements" % len(stmts))
    else:
        logger.info("Getting preassembled statements.")
        # Get pairs of pa statements with their supporting raw statements
        # (the number of pairs equals the number of supporting statements).
        clauses += db.join(db.PAStatements, db.RawStatements)
        pa_raw_stmt_pairs = db.select_all([db.PAStatements, db.RawStatements],
                                          *clauses, yield_per=count)

        # Iterate over the batches to create the statement objects.
        stmt_dict = {}
        ev_dict = {}
        raw_stmt_dict = {}
        for stmt_pair_batch in batch_iter(pa_raw_stmt_pairs, count):
            # Instantiate the PA statement objects, and record the uuid evidence
            # (raw statement) links.
            raw_stmt_obj_list = []
            for pa_stmt_db_obj, raw_stmt_db_obj in stmt_pair_batch:
                k = pa_stmt_db_obj.mk_hash
                if k not in stmt_dict.keys():
                    stmt_dict[k] = _get_statement_object(pa_stmt_db_obj)
                    ev_dict[k] = [raw_stmt_db_obj.uuid,]
                else:
                    ev_dict[k].append(raw_stmt_db_obj.uuid)
                raw_stmt_obj_list.append(raw_stmt_db_obj)

            logger.info("Up to %d pa statements, with %d pieces of evidence in "
                        "all." % (len(stmt_dict), len(ev_dict)))

            # Instantiate the raw statements.
            raw_stmts = make_raw_stmts_from_db_list(db, raw_stmt_obj_list)
            raw_stmt_dict.update({s.uuid: s for s in raw_stmts})
            logger.info("Processed %d raw statements." % len(raw_stmts))

        # Attach the evidence
        logger.info("Inserting evidence.")
        for k, uuid_list in ev_dict.items():
            stmt_dict[k].evidence = [raw_stmt_dict[uuid].evidence[0]
                                     for uuid in uuid_list]

        # Populate the supports/supported by fields.
        if with_support:
            logger.info("Populating support links.")
            support_links = db.filter_query(
                [db.PASupportLinks.supported_mk_hash,
                 db.PASupportLinks.supporting_mk_hash],
                or_(db.PASupportLinks.supported_mk_hash.in_(stmt_dict.keys()),
                    db.PASupportLinks.supporting_mk_hash.in_(stmt_dict.keys()))
                ).distinct().yield_per(count)
            for supped_hash, supping_hash in support_links:
                supped_stmt = stmt_dict.get(supped_hash,
                                            Unresolved(shallow_hash=supped_hash))
                supping_stmt = stmt_dict.get(supping_hash,
                                             Unresolved(shallow_hash=supping_hash))
                if not isinstance(supped_stmt, str):
                    supped_stmt.supported_by.append(supping_stmt)
                if not isinstance(supping_stmt, str):
                    supping_stmt.supports.append(supped_stmt)

        stmts = list(stmt_dict.values())
        logger.info("In all, there are %d pa statements." % len(stmts))

    return stmts
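A hypothetical usage sketch for get_statements, assuming the PAStatements table exposes a type column (an assumption not confirmed on this page).

db = get_primary_db()
phos_stmts = get_statements([db.PAStatements.type == 'Phosphorylation'],
                            count=1000, db=db, preassembled=True,
                            with_support=False)
print('Got %d preassembled statements' % len(phos_stmts))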
Example #19
File: learn.py  Project: steppi/gilda
        ambigs = get_filter_ambigs(str_counts)
        all_ambigs_pmids = get_all_pmids(ambigs)

    # Construct list of ambiguities
    #print('Found a total of %d ambiguities.' % len(ambigs))

    # Learn models
    param_grid = {
        'C': [10.0],
        'max_features': [100, 1000],
        'ngram_range': [(1, 2)]
    }
    models = {}
    if args.nproc == 1:
        for idx, ambig_terms_batch in \
                enumerate(batch_iter(all_ambigs_pmids, 10)):
            pickle_name = 'gilda_ambiguities_hgnc_mesh_%d.pkl' % idx
            models = learn_batch(ambig_terms_batch)
            with open(pickle_name, 'wb') as fh:
                pickle.dump(models, fh)
    else:
        pool = Pool(args.nproc)
        fun = functools.partial(learn_model, params=param_grid)
        pkl_idx = 0
        models = {}
        for count, model in enumerate(
                pool.imap_unordered(fun, all_ambigs_pmids, chunksize=10)):
            print('#### %d ####' % count)
            if model is None:
                print('Model is None, skipping')
            else:
Example #20
    def decorator(*args, **kwargs):
        tracker = LogTracker()
        start_time = datetime.now()
        logger.info("Got query for %s at %s!" % (f.__name__, start_time))

        query = request.args.copy()
        offs = query.pop('offset', None)
        ev_lim = query.pop('ev_limit', None)
        best_first_str = query.pop('best_first', 'true')
        best_first = True if best_first_str.lower() == 'true' \
                             or best_first_str else False
        do_stream_str = query.pop('stream', 'false')
        do_stream = True if do_stream_str == 'true' else False
        max_stmts = min(int(query.pop('max_stmts', MAX_STATEMENTS)),
                        MAX_STATEMENTS)
        format = query.pop('format', 'json')

        api_key = query.pop('api_key', None)
        logger.info("Running function %s after %s seconds." %
                    (f.__name__, sec_since(start_time)))
        result = f(query, offs, max_stmts, ev_lim, best_first, *args, **kwargs)
        logger.info("Finished function %s after %s seconds." %
                    (f.__name__, sec_since(start_time)))

        # Redact elsevier content for those without permission.
        if api_key is None or not _has_elsevier_auth(api_key):
            for stmt_json in result['statements'].values():
                for ev_json in stmt_json['evidence']:
                    if get_source(ev_json) == 'elsevier':
                        text = ev_json['text']
                        if len(text) > 200:
                            ev_json['text'] = text[:200] + REDACT_MESSAGE
        logger.info("Finished redacting evidence for %s after %s seconds." %
                    (f.__name__, sec_since(start_time)))
        result['offset'] = offs
        result['evidence_limit'] = ev_lim
        result['statement_limit'] = MAX_STATEMENTS
        result['statements_returned'] = len(result['statements'])

        if format == 'html':
            stmts_json = result.pop('statements')
            ev_totals = result.pop('evidence_totals')
            stmts = stmts_from_json(stmts_json.values())
            html_assembler = HtmlAssembler(stmts,
                                           result,
                                           ev_totals,
                                           title='INDRA DB REST Results',
                                           db_rest_url=request.url_root[:-1])
            content = html_assembler.make_model()
            if tracker.get_messages():
                level_stats = [
                    '%d %ss' % (n, lvl.lower())
                    for lvl, n in tracker.get_level_stats().items()
                ]
                msg = ' '.join(level_stats)
                content = html_assembler.append_warning(msg)
            mimetype = 'text/html'
        else:  # Return JSON for all other values of the format argument
            result.update(tracker.get_level_stats())
            content = json.dumps(result)
            mimetype = 'application/json'

        if do_stream:
            # Returning a generator should stream the data.
            resp_json_bts = content
            gen = batch_iter(resp_json_bts, 10000)
            resp = Response(gen, mimetype=mimetype)
        else:
            resp = Response(content, mimetype=mimetype)
        logger.info("Exiting with %d statements with %d evidence of size "
                    "%f MB after %s seconds." %
                    (result['statements_returned'], result['total_evidence'],
                     sys.getsizeof(resp.data) / 1e6, sec_since(start_time)))
        return resp
Example #21
def distill_stmts(db, get_full_stmts=False, clauses=None, num_procs=1,
                  delete_duplicates=True, weed_evidence=True, batch_size=1000):
    """Get a corpus of statements from clauses and filters duplicate evidence.

    Parameters
    ----------
    db : :py:class:`DatabaseManager`
        A database manager instance to access the database.
    get_full_stmts : bool
        By default (False), only Statement ids (the primary index of Statements
        on the database) are returned. However, if set to True, serialized
        INDRA Statements will be returned. Note that this will in general be
        VERY large in memory, and therefore should be used with caution.
    clauses : None or list of sqlalchemy clauses
        By default None. Specify sqlalchemy clauses to reduce the scope of
        statements, e.g. `clauses=[db.Statements.type == 'Phosphorylation']` or
        `clauses=[db.Statements.uuid.in_([<uuids>])]`.
    num_procs : int
        Select the number of processes that can be used.
    delete_duplicates : bool
        Choose whether you want to delete the statements that are found to be
        duplicates.
    weed_evidence : bool
        If True, evidence links that exist for raw statements that now have
        better alternatives will be removed. If False, such links will remain,
        which may cause problems in incremental pre-assembly.
    batch_size : int
        The number of duplicate statement ids handled per batch when deleting
        duplicates. Default: 1000.

    Returns
    -------
    stmt_ret : set
        A set of either statement ids or serialized statements, depending on
        `get_full_stmts`.
    """
    if delete_duplicates:
        logger.info("Looking for ids from existing links...")
        linked_sids = {sid for sid,
                        in db.select_all(db.RawUniqueLinks.raw_stmt_id)}
    else:
        linked_sids = set()

    # Get de-duplicated Statements, and duplicate uuids, as well as uuid of
    # Statements that have been improved upon...
    logger.info("Sorting reading statements...")
    stmt_nd = _get_reading_statement_dict(db, clauses, get_full_stmts)

    stmts, duplicate_sids, bettered_duplicate_sids = \
        _get_filtered_rdg_statements(stmt_nd, get_full_stmts, linked_sids)
    logger.info("After filtering reading: %d unique statements, %d exact "
                "duplicates, %d with results from better resources available."
                % (len(stmts), len(duplicate_sids),
                   len(bettered_duplicate_sids)))
    assert not linked_sids & duplicate_sids, linked_sids & duplicate_sids
    del stmt_nd  # This takes up a lot of memory, and is done being used.

    db_stmts, db_duplicates = \
        _get_filtered_db_statements(db, get_full_stmts, clauses, linked_sids,
                                    num_procs)
    stmts |= db_stmts
    duplicate_sids |= db_duplicates
    logger.info("After filtering database statements: %d unique, %d duplicates."
                % (len(stmts), len(duplicate_sids)))
    assert not linked_sids & duplicate_sids, linked_sids & duplicate_sids

    # Remove support links for statements that have better versions available.
    bad_link_sids = bettered_duplicate_sids & linked_sids
    if len(bad_link_sids) and weed_evidence:
        logger.info("Removing bettered evidence links...")
        rm_links = db.select_all(
            db.RawUniqueLinks,
            db.RawUniqueLinks.raw_stmt_id.in_(bad_link_sids)
            )
        db.delete_all(rm_links)

    # Delete exact duplicates
    if len(duplicate_sids) and delete_duplicates:
        logger.info("Deleting duplicates...")
        for dup_id_batch in batch_iter(duplicate_sids, batch_size, set):
            bad_stmts = db.select_all(db.RawStatements,
                                      db.RawStatements.id.in_(dup_id_batch))
            bad_sid_set = {s.id for s in bad_stmts}
            bad_agents = db.select_all(db.RawAgents,
                                       db.RawAgents.stmt_id.in_(bad_sid_set))
            logger.info("Deleting %d agents associated with redundant raw "
                        "statements." % len(bad_agents))
            db.delete_all(bad_agents)
            logger.info("Deleting %d redundant raw statements." % len(bad_stmts))
            db.delete_all(bad_stmts)

    return stmts
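A hypothetical usage sketch for distill_stmts, adapted from the clause example in its docstring and assuming the RawStatements table has a type column.

db = get_primary_db()
stmt_ids = distill_stmts(db, get_full_stmts=False,
                         clauses=[db.RawStatements.type == 'Phosphorylation'])
print('Kept %d distinct raw statement ids' % len(stmt_ids))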
Example #22
def match_correlations(corr_z: pd.DataFrame,
                       sd_range: Tuple[float, Union[float, None]],
                       script_settings: Dict[str, Union[str, int, float]],
                       graph_filepath: str,
                       z_corr_filepath: str,
                       apriori_explained: Optional[Dict[str, str]] = None,
                       graph_type: str = 'unsigned',
                       allowed_ns: Optional[List[str]] = None,
                       allowed_sources: Optional[List[str]] = None,
                       is_a_part_of: Optional[List[str]] = None,
                       expl_funcs: Optional[List[str]] = None,
                       reactome_filepath: Optional[str] = None,
                       indra_date: Optional[str] = None,
                       info: Optional[Dict[str, Any]] = None,
                       depmap_date: Optional[str] = None,
                       n_chunks: Optional[int] = 8,
                       immediate_only: Optional[bool] = False,
                       return_unexplained: Optional[bool] = False,
                       reactome_dict: Optional[Dict[str, Any]] = None,
                       subset_list: Optional[List[Union[str, int]]] = None):
    """The main loop for matching correlations with INDRA explanations

    Note that indranet is assumed to be a global variable that must be set
    outside of this function.

    Parameters
    ----------
    corr_z : pd.DataFrame
        The pre-processed correlation matrix. No more processing of the
        matrix should have to be done here, i.e. it should already have
        filtered the correlations to the proper SD ranges and removed the
        genes that are not applicable for this explanation,
        self correlations should also have been removed.
    indranet : nx.DiGraph
        The graph representation of the indra network. Each edge should
        have an attribute named 'statements' containing a list of sources
        supporting that edge. If signed search, indranet is expected to be an
        nx.MultiDiGraph with edges keyed by (gene, gene, sign) tuples.
    sd_range : tuple[float]
        The SD ranges that the corr_z is filtered to
    script_settings : Dict[str, Union[str, int, float]]
        Dictionary with script settings for the purpose of book keeping
    graph_filepath : str
        File path to the graph used
    z_corr_filepath : str
        File path to the correlation matrix used
    reactome_filepath : Optional[str]
        File path to the reactome data
    subset_list : Optional[List[Union[str, int]]]
        If provided, restrict the pairs drawn from the correlation matrix to
        those involving the given names/ids. Default: None.

    Returns
    -------
    depmap_analysis.util.statistics.DepMapExplainer
        An instance of the DepMapExplainer class containing the explanations
        for the correlations.
    """
    # Map each expl type to a function that handles that explanation
    if not expl_funcs:
        # No function names provided, use all explanation functions
        logger.info('All explanation types used')
        expl_types = {
            funcname_to_colname[func_name]: func
            for func_name, func in expl_functions.items()
        }
    else:
        # Map function names to functions, checking that each name
        # corresponds to a registered explanation function
        expl_types = {}
        for func_name in expl_funcs:
            if func_name not in expl_functions:
                logger.warning(f'{func_name} does not map to a registered '
                               f'explanation function. Allowed functions '
                               f'{", ".join(expl_functions.keys())}')
            else:
                expl_types[funcname_to_colname[func_name]] = \
                    expl_functions[func_name]

    if not len(expl_types):
        raise ValueError('No explanation functions provided')

    bool_columns = ('not_in_graph', 'explained') + tuple(expl_types.keys())
    stats_columns = id_columns + bool_columns
    expl_cols = expl_columns
    apriori_explained = apriori_explained or {}

    logger.info(f'Doing correlation matching with {graph_type} graph')

    # Get options
    if allowed_ns is not None:
        allowed_ns = {n.lower() for n in allowed_ns}
        logger.info('Only allowing the following namespaces: %s' %
                    ', '.join(allowed_ns))

    if allowed_sources is not None:
        allowed_sources = {s.lower() for s in allowed_sources}
        logger.info('Only allowing the following sources: %s' %
                    ', '.join(allowed_sources))

    is_a_part_of = is_a_part_of or []

    # Try to get dates of files from file names and file info
    ymd_now = datetime.now().strftime('%Y%m%d')
    indra_date = indra_date or ymd_now
    depmap_date = depmap_date or ymd_now

    logger.info('Calculating number of pairs to check...')
    estim_pairs = get_pairs(corr_z, subset_list=subset_list)
    logger.info(f'Starting workers at {datetime.now().strftime("%H:%M:%S")} '
                f'with about {estim_pairs} pairs to check')
    tstart = time()

    with mp.Pool() as pool:
        MAX_SUB = 512
        n_sub = min(n_chunks, MAX_SUB)
        chunksize = get_chunk_size(n_sub, estim_pairs)

        # Pick one more so we don't do more than MAX_SUB
        chunksize += 1 if n_sub == MAX_SUB else 0
        chunk_iter = batch_iter(iterator=corr_matrix_to_generator(
            corr_z, subset_list=subset_list),
                                batch_size=chunksize,
                                return_func=list)
        for chunk in chunk_iter:
            pool.apply_async(
                func=_match_correlation_body,
                # args should match the args for func
                args=(chunk, expl_types, stats_columns, expl_cols,
                      bool_columns, graph_type, apriori_explained, allowed_ns,
                      allowed_sources, is_a_part_of, immediate_only,
                      return_unexplained, reactome_dict),
                callback=success_callback,
                error_callback=error_callback)

        logger.info('Done submitting work to pool workers')
        pool.close()
        pool.join()

    print(f'Execution time: {time() - tstart} seconds')
    print(f'Done at {datetime.now().strftime("%H:%M:%S")}')

    # Here initialize a DepMapExplainer and append the result for the
    # different processes
    explainer = DepMapExplainer(stats_columns=stats_columns,
                                expl_columns=expl_columns,
                                graph_filepath=graph_filepath,
                                z_corr_filepath=z_corr_filepath,
                                reactome_filepath=reactome_filepath,
                                info={
                                    'indra_network_date': indra_date,
                                    'depmap_date': depmap_date,
                                    'sd_range': sd_range,
                                    'graph_type': graph_type,
                                    **(info or {})
                                },
                                script_settings=script_settings)

    logger.info(f'Generating DepMapExplainer with output from '
                f'{len(output_list)} results')
    for stats_dict, expl_dict in output_list:
        explainer.stats_df = explainer.stats_df.append(other=pd.DataFrame(
            data=stats_dict))
        explainer.expl_df = explainer.expl_df.append(other=pd.DataFrame(
            data=expl_dict))

    return explainer
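The get_chunk_size helper used above is not shown on this page. Based on how the test in Example #25 below exercises it, a plausible sketch, purely as an assumption, is that it picks a batch size so that the items split into roughly the requested number of chunks.

import math

def get_chunk_size(n_chunks, total_items):
    # Assumed behavior: choose a chunk size such that total_items splits
    # into about n_chunks batches (the real implementation may differ).
    return max(1, math.ceil(total_items / n_chunks))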
Example #23
    def filter_text_content(self, db, tc_data):
        """Link Text Content entries to corresponding Text Refs and filter
        out entries already in the database."""
        if not len(tc_data):
            return [], set()
        logger.info("Beginning to filter text content...")
        tr_list = []
        # Step 1: Build a dictionary matching IDs to text ref objects
        for ix, tc_batch in enumerate(batch_iter(tc_data, 5000)):
            # Get the sets of IDs for this batch
            # Only use the generator once!
            ids = [(tc['pmid'], tc['pmcid'], tc['doi']) for tc in tc_batch]
            pmids, pmcids, dois = list(zip(*ids))
            # Remove any Nones and convert to sets
            pmid_set = set([i for i in pmids if i is not None])
            pmcid_set = set([i for i in pmcids if i is not None])
            doi_set = set([i for i in dois if i is not None])
            # Get all TextRefs for the CORD19 IDs
            logger.debug("Getting text refs for CORD19 articles")
            tr_list += db.select_all(
                db.TextRef,
                sql_exp.or_(db.TextRef.pmid_in(pmid_set, filter_ids=True),
                            db.TextRef.pmcid_in(pmcid_set, filter_ids=True),
                            db.TextRef.doi_in(doi_set, filter_ids=True)))
        # Next, build dictionaries mapping IDs back to TextRef objects so
        # that we can link records in tc_data to TextRefs
        trs_by_doi = defaultdict(set)
        trs_by_pmc = defaultdict(set)
        trs_by_pmid = defaultdict(set)
        for tr in tr_list:
            if tr.doi:
                trs_by_doi[tr.doi].add(tr)
            if tr.pmcid:
                trs_by_pmc[tr.pmcid].add(tr)
            if tr.pmid:
                trs_by_pmid[tr.pmid].add(tr)

        # Now, build a new dictionary of text content including the TRIDs
        # rather than pmid/pmcid/doi
        # A list of dictionaries each containing: tr_id, source, format and
        # text_type
        flawed_tcs = set()
        tc_data_by_tr = []
        for tc_entry in tc_data:
            by_tr_entry = {}
            for field in ('source', 'format', 'text_type', 'content'):
                by_tr_entry[field] = tc_entry[field]
            tr_ids_for_tc = set()
            for id_type, trs_by_id in (('pmid', trs_by_pmid),
                                       ('pmcid', trs_by_pmc),
                                       ('doi', trs_by_doi)):
                tr_set = trs_by_id.get(tc_entry[id_type])
                if tr_set is not None:
                    # assert len(tr_set) == 1
                    if len(tr_set) != 1:
                        logger.warning(
                            '%s %s is associated with multiple TextRefs: %s' %
                            (id_type, tc_entry[id_type], tr_set))
                        continue
                    tr = list(tr_set)[0]
                    tr_ids_for_tc.add(tr.id)
            # Because this function is called using tc_data that has already
            # been filtered by text ref, we should always get unambiguous
            # matches to text_refs here.
            if len(tr_ids_for_tc) != 1:
                log_entry = (tc_entry['pmid'], tc_entry['pmcid'],
                             tc_entry['doi'], tc_entry['cord_uid'],
                             tuple(tr_ids_for_tc))
                logger.warning('Missing or ambiguous match to text ref: %s' %
                               str(log_entry))
                flawed_tcs.add(log_entry)
            else:
                tr_id = list(tr_ids_for_tc)[0]
                by_tr_entry['trid'] = tr_id
                tc_data_by_tr.append(by_tr_entry)

        # Step 2: Get existing Text Content objects corresponding to the
        # the given text refs with the same format and source.
        # This should be a very small list, in general.
        existing_tc_records = []
        for source, text_type in (('cord19_abstract', 'abstract'),
                                  ('cord19_pmc_xml', 'fulltext'),
                                  ('cord19_pdf', 'fulltext')):
            logger.debug('Finding existing text content from db for '
                         'source type %s' % source)
            tc_by_source = [
                tc_entry for tc_entry in tc_data_by_tr
                if tc_entry['source'] == source
            ]
            existing_tcs = db.select_all(
                db.TextContent,
                db.TextContent.text_ref_id.in_([
                    tc['trid'] for tc in tc_by_source
                ]), db.TextContent.source == source,
                db.TextContent.format == 'text',
                db.TextContent.text_type == text_type)
            # Reformat Text Content objects to list of tuples
            existing_tc_records += [(tc.text_ref_id, tc.source, tc.format,
                                     tc.text_type) for tc in existing_tcs]
            logger.debug("Found %d existing records on the db for %s." %
                         (len(existing_tc_records), source))
        # Convert list of dicts into a list of tuples
        tc_records = []
        for tc_entry in tc_data_by_tr:
            tc_records.append(
                (tc_entry['trid'], tc_entry['source'], tc_entry['format'],
                 tc_entry['text_type'], tc_entry['content']))

        # Filter the TC records to exclude
        filtered_tc_records = [
            rec for rec in tc_records if rec[:-1] not in existing_tc_records
        ]
        logger.info("Finished filtering the text content.")
        return list(set(filtered_tc_records)), flawed_tcs
Example #24
def main():
    arg_parser = get_parser()
    args = arg_parser.parse_args()

    s3 = boto3.client('s3')
    s3_log_prefix = get_s3_job_log_prefix(args.s3_base, args.job_name)
    logger.info("Using log prefix \"%s\"" % s3_log_prefix)
    id_list_key = args.s3_base + 'id_list'
    logger.info("Looking for id list on s3 at \"%s\"" % id_list_key)
    try:
        id_list_obj = s3.get_object(Bucket=bucket_name, Key=id_list_key)
    except botocore.exceptions.ClientError as e:
        # Handle a missing object gracefully
        if e.response['Error']['Code'] == 'NoSuchKey':
            logger.info('Could not find PMID list file at %s, exiting' %
                        id_list_key)
            sys.exit(1)
        # If there was some other kind of problem, re-raise the exception
        else:
            raise e

    # Get the content from the object
    id_list_str = id_list_obj['Body'].read().decode('utf8').strip()
    id_str_list = id_list_str.splitlines()[args.start_index:args.end_index]
    random.shuffle(id_str_list)
    tcids = [int(line.strip()) for line in id_str_list]

    # Get the reader objects
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)
    kwargs = {'base_dir': args.out_dir, 'n_proc': args.num_cores}
    readers = construct_readers(args.readers, **kwargs)

    # Record the reader versions used in this run.
    reader_versions = {}
    for reader in readers:
        reader_versions[reader.name] = reader.get_version()
    s3.put_object(Bucket=bucket_name,
                  Key=get_s3_reader_version_loc(args.s3_base, args.job_name),
                  Body=json.dumps(reader_versions))

    # Some combinations of options don't make sense:
    forbidden_combos = [('all', 'unread'), ('none', 'unread'),
                        ('none', 'none')]
    assert (args.read_mode, args.rslt_mode) not in forbidden_combos, \
        ("The combination of reading mode %s and statement mode %s is not "
         "allowed." % (args.read_mode, args.rslt_mode))

    # Get a handle for the database
    if args.test:
        from indra_db.tests.util import get_temp_db
        db = get_temp_db(clear=True)
    else:
        db = None

    # Read everything ========================================
    if args.batch is None:
        run_reading(readers,
                    tcids,
                    verbose=True,
                    db=db,
                    reading_mode=args.read_mode,
                    rslt_mode=args.rslt_mode)
    else:
        for tcid_batch in batch_iter(tcids, args.batch):
            run_reading(readers,
                        tcid_batch,
                        verbose=True,
                        db=db,
                        reading_mode=args.read_mode,
                        rslt_mode=args.rslt_mode)

    # Preserve the sparser logs
    contents = os.listdir('.')
    logger.info("Checking for any log files to cache:\n" + '\n'.join(contents))
    sparser_logs = []
    trips_logs = []
    for fname in contents:
        # Check if this file is a sparser log
        if fname.startswith('sparser') and fname.endswith('log'):
            sparser_logs.append(fname)
        elif is_trips_datestring(fname):
            for sub_fname in os.listdir(fname):
                if sub_fname.endswith('.log') or sub_fname.endswith('.err'):
                    trips_logs.append(os.path.join(fname, sub_fname))

    _dump_logs_to_s3(s3, s3_log_prefix, 'sparser', sparser_logs)
    _dump_logs_to_s3(s3, s3_log_prefix, 'trips', trips_logs)
    return
Example #25
def test_iterator_slicing():
    size = 50
    a = _gen_sym_df(size)

    pairs = set()
    n = 0
    for n in range(size):
        k = 0
        row, col = _get_off_diag_pair(size)
        while (row, col) in pairs:
            row, col = _get_off_diag_pair(size)
            k += 1
            if k > 1000:
                print('Too many while iterations, breaking')
                break
        if k > 1000:
            break
        a.iloc[row, col] = np.nan
        a.iloc[col, row] = np.nan
        pairs.add((row, col))
        pairs.add((col, row))

    pairs_removed = n + 1

    # Assert that we're correct so far
    assert (size**2 - size - 2 * pairs_removed) / 2 == get_pairs(a)

    # Check that the iterator slicing for multiprocessing runs through all
    # the pairs

    # Get total pairs available
    total_pairs = get_pairs(a)

    # Chunks wanted
    chunks_wanted = 10

    chunksize = get_chunk_size(chunks_wanted, total_pairs)

    chunk_iter = batch_iter(iterator=corr_matrix_to_generator(a),
                            batch_size=chunksize,
                            return_func=list)

    pair_count = 0
    chunk_ix = 0
    for chunk_ix, list_of_pairs in enumerate(chunk_iter):
        pair_count += len([t for t in list_of_pairs if t is not None])

    # Were all pairs looped?
    assert pair_count == total_pairs, \
        f'pair_count={pair_count} total_pairs={total_pairs}'
    # Does the number of loop iterations correspond to the number of chunks
    # wanted?
    assert chunk_ix + 1 == chunks_wanted, \
        f'chunk_ix+1={chunk_ix + 1}, chunks_wanted={chunks_wanted}'

    # Redo the same with subset of names
    name_subset = list(
        np.random.choice(a.columns.values, size=size // 3, replace=False))
    # Add a name that does not exist in the original df
    name_subset.append(size + 2)

    # Get total pairs available
    total_pairs_permute = get_pairs(a, subset_list=name_subset)

    # Chunks wanted
    chunks_wanted = 10

    chunksize = get_chunk_size(chunks_wanted, total_pairs_permute)

    chunk_iter = batch_iter(iterator=corr_matrix_to_generator(
        a, subset_list=name_subset),
                            batch_size=chunksize,
                            return_func=list)

    pair_count = 0
    chunk_ix = 0
    for chunk_ix, list_of_pairs in enumerate(chunk_iter):
        pair_count += len([t for t in list_of_pairs if t is not None])

    # Were all pairs looped?
    assert pair_count == total_pairs_permute, \
        f'pair_count={pair_count} total_pairs={total_pairs_permute}'

    # Does the number of loop iterations correspond to the number of chunks
    # wanted?
    assert chunk_ix + 1 == chunks_wanted, \
        f'chunk_ix+1={chunk_ix + 1}, chunks_wanted={chunks_wanted}'
Example #26
def reground_texts(texts,
                   ont_yml,
                   webservice,
                   topk=10,
                   is_canonicalized=False,
                   filter=True,
                   cache_path=None):
    """Ground concept texts given an ontology with an Eidos web service.

    Parameters
    ----------
    texts : list[str]
        A list of concept texts to ground.
    ont_yml : str
        A serialized YAML string representing the ontology.
    webservice : str
        The address where the Eidos web service is running, e.g.,
        http://localhost:9000.
    topk : Optional[int]
        The number of top scoring groundings to return. Default: 10
    is_canonicalized : Optional[bool]
        If True, the texts are assumed to be canonicalized. If False,
        Eidos will canonicalize the texts which yields much better groundings
        but is slower. Default: False
    filter : Optional[bool]
        If True, Eidos filters the ontology to remove determiners from examples
        and other similar operations. Should typically be set to True.
        Default: True
    cache_path : Optional[str]
        If given, the path of a pickle file used to cache groundings. Cached
        groundings are reused, and any newly obtained groundings are added to
        the cache before it is written back. Default: None

    Returns
    -------
    list
        A list of groundings from the Eidos webservice, one entry per input
        text.
    """
    all_results = []
    grounding_cache = {}
    if cache_path:
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as fh:
                grounding_cache = pickle.load(fh)
                logger.info('Loaded %d groundings from cache' %
                            len(grounding_cache))
    texts_to_ground = list(set(texts) - set(grounding_cache.keys()))
    logger.info('Grounding a total of %d texts' % len(texts_to_ground))
    for text_batch in tqdm.tqdm(batch_iter(texts_to_ground,
                                           batch_size=500,
                                           return_func=list),
                                total=math.ceil(len(texts_to_ground) / 500)):
        params = {
            'ontologyYaml': ont_yml,
            'texts': text_batch,
            'topk': topk,
            'isAlreadyCanonicalized': is_canonicalized,
            'filter': filter
        }
        res = requests.post('%s/reground' % webservice, json=params)
        res.raise_for_status()
        grounding_for_texts = grounding_dict_to_list(res.json())

        for txt, grounding in zip(text_batch, grounding_for_texts):
            grounding_cache[txt] = grounding

    all_results = [grounding_cache[txt] for txt in texts]
    if cache_path:
        with open(cache_path, 'wb') as fh:
            pickle.dump(grounding_cache, fh)
    return all_results
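
A minimal usage sketch for reground_texts, assuming an Eidos web service is already running at the address shown and that the ontology YAML has been read into a string; the file paths and example texts below are hypothetical placeholders.

if __name__ == '__main__':
    # Hypothetical usage of reground_texts; the ontology path, cache path and
    # texts are placeholders, and an Eidos service must be reachable at the
    # given address.
    with open('wm_ontology.yml', 'r') as fh:
        ont_yml = fh.read()
    texts = ['rainfall', 'food insecurity', 'conflict']
    groundings = reground_texts(texts, ont_yml,
                                webservice='http://localhost:9000',
                                topk=5,
                                cache_path='grounding_cache.pkl')
    for text, grounding in zip(texts, groundings):
        print(text, grounding)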
Example #27
def get_statements(clauses,
                   count=1000,
                   do_stmt_count=False,
                   db=None,
                   preassembled=True,
                   with_support=False,
                   fix_refs=True,
                   with_evidence=True):
    """Select statements according to a given set of clauses.

    Parameters
    ----------
    clauses : list
        list of sqlalchemy WHERE clauses to pass to the filter query.
    count : int
        Number of statements to retrieve and process in each batch.
    do_stmt_count : bool
        Whether or not to perform an initial statement counting step to give
        more meaningful progress messages.
    db : :py:class:`DatabaseManager`
        Optionally specify a database manager that attaches to something
        besides the primary database, for example a local database instance.
    preassembled : bool
        If true, statements will be selected from the table of pre-assembled
        statements. Otherwise, they will be selected from the raw statements.
        Default is True.
    with_support : bool
        Choose whether to populate the supports and supported_by list
        attributes of the Statement objects. Generally results in slower
        queries.
    with_evidence : bool
        Choose whether or not to populate the evidence list attribute of the
        Statements. As with `with_support`, setting this to True will take
        longer.
    fix_refs : bool
        The paper refs within the evidence objects are not populated in the
        database, and thus must be filled using the relations in the database.
        If True (default), the `pmid` field of each Statement Evidence object
        is set to the correct PMIDs, or None if no PMID is available. If False,
        the `pmid` field defaults to the value populated by the reading
        system.

    Returns
    -------
    list
        A list of Statements from the database corresponding to the query.
    """
    warnings.warn(('This module is being taken out of service, as the tools '
                   'have become deprecated. Moreover, the service has been '
                   're-implemented to use newer tools as best as possible, '
                   'but some results may be subtly different.'),
                  DeprecationWarning)
    cnt = count
    if db is None:
        db = get_primary_db()

    stmts_tblname = 'pa_statements' if preassembled else 'raw_statements'

    if not preassembled:
        stmts = []
        q = db.filter_query(stmts_tblname, *clauses)
        if do_stmt_count:
            logger.info("Counting statements...")
            num_stmts = q.count()
            logger.info("Total of %d statements" % num_stmts)
        db_stmts = q.yield_per(cnt)
        for subset in batch_iter(db_stmts, cnt):
            stmts.extend(
                get_raw_stmts_frm_db_list(db,
                                          subset,
                                          with_sids=False,
                                          fix_refs=fix_refs))
            if do_stmt_count:
                logger.info("%d of %d statements" % (len(stmts), num_stmts))
            else:
                logger.info("%d statements" % len(stmts))
    else:
        logger.info("Getting preassembled statements.")
        if with_evidence:
            logger.info("Getting preassembled statements.")
            # Get pairs of pa statements with their linked raw statements
            clauses += [
                db.PAStatements.mk_hash == db.RawUniqueLinks.pa_stmt_mk_hash,
                db.RawStatements.id == db.RawUniqueLinks.raw_stmt_id
            ]
            pa_raw_stmt_pairs = \
                db.select_all([db.PAStatements, db.RawStatements],
                              *clauses, yield_per=cnt)
            stmt_dict = _process_pa_statement_res_wev(db,
                                                      pa_raw_stmt_pairs,
                                                      count=cnt,
                                                      fix_refs=fix_refs)
        else:
            # Get just pa statements without their supporting raw statement(s).
            pa_stmts = db.select_all(db.PAStatements, *clauses, yield_per=cnt)
            stmt_dict = _process_pa_statement_res_nev(pa_stmts, count=cnt)

        # Populate the supports/supported by fields.
        if with_support:
            get_support(stmt_dict, db=db)

        stmts = list(stmt_dict.values())
        logger.info("In all, there are %d pa statements." % len(stmts))

    return stmts
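
A usage sketch for the function above; the matches-key hashes are hypothetical placeholders, and the db.PAStatements.mk_hash column is taken from the columns already referenced in these examples.

if __name__ == '__main__':
    # Hypothetical usage of get_statements; a reachable primary database is
    # assumed and the mk_hash values are placeholders.
    db = get_primary_db()
    example_hashes = [1234567890, -987654321]  # hypothetical matches-key hashes
    clauses = [db.PAStatements.mk_hash.in_(example_hashes)]
    stmts = get_statements(clauses, count=500, db=db,
                           preassembled=True, with_evidence=True)
    print('Retrieved %d preassembled statements.' % len(stmts))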
Example #28
def get_statements(clauses, count=1000, do_stmt_count=False, db=None,
                   preassembled=True, with_support=False, fix_refs=True,
                   with_evidence=True):
    """Select statements according to a given set of clauses.

    Parameters
    ----------
    clauses : list
        list of sqlalchemy WHERE clauses to pass to the filter query.
    count : int
        Number of statements to retrieve and process in each batch.
    do_stmt_count : bool
        Whether or not to perform an initial statement counting step to give
        more meaningful progress messages.
    db : :py:class:`DatabaseManager`
        Optionally specify a database manager that attaches to something
        besides the primary database, for example a local database instance.
    preassembled : bool
        If true, statements will be selected from the table of pre-assembled
        statements. Otherwise, they will be selected from the raw statements.
        Default is True.
    with_support : bool
        Choose whether to populate the supports and supported_by list attributes
        of the Statement objects. Generally results in slower queries.
    with_evidence : bool
        Choose whether or not to populate the evidence list attribute of the
        Statements. As with `with_support`, setting this to True will take
        longer.
    fix_refs : bool
        The paper refs within the evidence objects are not populated in the
        database, and thus must be filled using the relations in the database.
        If True (default), the `pmid` field of each Statement Evidence object
        is set to the correct PMIDs, or None if no PMID is available. If False,
        the `pmid` field defaults to the value populated by the reading
        system.

    Returns
    -------
    list
        A list of Statements from the database corresponding to the query.
    """
    if db is None:
        db = get_primary_db()

    stmts_tblname = 'pa_statements' if preassembled else 'raw_statements'

    if not preassembled:
        stmts = []
        q = db.filter_query(stmts_tblname, *clauses)
        if do_stmt_count:
            logger.info("Counting statements...")
            num_stmts = q.count()
            logger.info("Total of %d statements" % num_stmts)
        db_stmts = q.yield_per(count)
        for subset in batch_iter(db_stmts, count):
            stmts.extend(get_raw_stmts_frm_db_list(db, subset, with_sids=False,
                                                   fix_refs=fix_refs))
            if do_stmt_count:
                logger.info("%d of %d statements" % (len(stmts), num_stmts))
            else:
                logger.info("%d statements" % len(stmts))
    else:
        logger.info("Getting preassembled statements.")
        if with_evidence:
            logger.info("Getting preassembled statements.")
            # Get pairs of pa statements with their linked raw statements
            clauses += db.join(db.PAStatements, db.RawStatements)
            pa_raw_stmt_pairs = \
                db.select_all([db.PAStatements, db.RawStatements],
                              *clauses, yield_per=count)

            # Iterate over the batches to create the statement objects.
            stmt_dict = {}
            ev_dict = {}
            raw_stmt_dict = {}
            total_ev = 0
            for stmt_pair_batch in batch_iter(pa_raw_stmt_pairs, count):
                # Instantiate the PA statement objects, and record the uuid
                # evidence (raw statement) links.
                raw_stmt_objs = []
                for pa_stmt_db_obj, raw_stmt_db_obj in stmt_pair_batch:
                    k = pa_stmt_db_obj.mk_hash
                    if k not in stmt_dict.keys():
                        stmt_dict[k] = _get_statement_object(pa_stmt_db_obj)
                        ev_dict[k] = [raw_stmt_db_obj.id,]
                    else:
                        ev_dict[k].append(raw_stmt_db_obj.id)
                    raw_stmt_objs.append(raw_stmt_db_obj)
                    total_ev += 1

                logger.info("Up to %d pa statements, with %d pieces of "
                            "evidence in all." % (len(stmt_dict), total_ev))

                # Instantiate the raw statements.
                raw_stmt_sid_tpls = get_raw_stmts_frm_db_list(db, raw_stmt_objs,
                                                              fix_refs,
                                                              with_sids=True)
                raw_stmt_dict.update({sid: s for sid, s in raw_stmt_sid_tpls})
                logger.info("Processed %d raw statements."
                            % len(raw_stmt_sid_tpls))

            # Attach the evidence
            logger.info("Inserting evidence.")
            for k, sid_list in ev_dict.items():
                stmt_dict[k].evidence = [raw_stmt_dict[sid].evidence[0]
                                         for sid in sid_list]
        else:
            # Get just pa statements without their supporting raw statement(s).
            pa_stmts = db.select_all(db.PAStatements, *clauses, yield_per=count)

            # Iterate over the batches to create the statement objects.
            stmt_dict = {}
            for stmt_pair_batch in batch_iter(pa_stmts, count):
                # Instantiate the PA statement objects.
                for pa_stmt_db_obj in stmt_pair_batch:
                    k = pa_stmt_db_obj.mk_hash
                    if k not in stmt_dict.keys():
                        stmt_dict[k] = _get_statement_object(pa_stmt_db_obj)

                logger.info("Up to %d pa statements in all." % len(stmt_dict))

        # Populate the supports/supported by fields.
        if with_support:
            logger.info("Populating support links.")
            support_links = db.select_all(
                [db.PASupportLinks.supported_mk_hash,
                 db.PASupportLinks.supporting_mk_hash],
                or_(db.PASupportLinks.supported_mk_hash.in_(stmt_dict.keys()),
                    db.PASupportLinks.supporting_mk_hash.in_(stmt_dict.keys()))
                )
            for supped_hash, supping_hash in set(support_links):
                if supped_hash == supping_hash:
                    assert False, 'Self-support found on-load.'
                supped_stmt = stmt_dict.get(supped_hash,
                                            Unresolved(shallow_hash=supped_hash))
                supping_stmt = stmt_dict.get(supping_hash,
                                             Unresolved(shallow_hash=supping_hash))
                supped_stmt.supported_by.append(supping_stmt)
                supping_stmt.supports.append(supped_stmt)

        stmts = list(stmt_dict.values())
        logger.info("In all, there are %d pa statements." % len(stmts))

    return stmts
Example #29
def test_iterator_slicing():
    size = 50
    a, pairs_removed = _get_df_w_nan(size)

    # Assert that we're correct so far

    # Get total pairs available:
    total_pairs = get_pairs(a)

    # all items - diagonal - all removed items off diagonal
    assert (size**2 - size - 2*pairs_removed) / 2 == total_pairs

    # Check that the iterator slicing for multiprocessing runs through all
    # the pairs

    # Chunks wanted
    chunks_wanted = 10

    chunksize = get_chunk_size(chunks_wanted, total_pairs)

    chunk_iter = batch_iter(iterator=corr_matrix_to_generator(a),
                            batch_size=chunksize, return_func=list)

    pair_count = 0
    chunk_ix = 0
    for chunk_ix, list_of_pairs in enumerate(chunk_iter):
        pair_count += len([(t[0][0], t[0][1], t[1]) for t in
                           list_of_pairs if t is not None])

    # Were all pairs looped?
    assert pair_count == total_pairs, \
        f'pair_count={pair_count} total_pairs={total_pairs}'
    # Does the number of loop iterations correspond to the number of chunks
    # wanted?
    assert chunk_ix + 1 == chunks_wanted, \
        f'chunk_ix+1={chunk_ix + 1}, chunks_wanted={chunks_wanted}'

    # Redo the same with subset of names
    name_subset = list(np.random.choice(a.columns.values,
                                        size=size // 3,
                                        replace=False))
    # Add a name that does not exist in the original df
    name_subset.append(size+2)

    # Get total pairs available
    total_pairs_permute = get_pairs(a, subset_list=name_subset)

    # Chunks wanted
    chunks_wanted = 10

    chunksize = get_chunk_size(chunks_wanted, total_pairs_permute)

    chunk_iter = batch_iter(
        iterator=corr_matrix_to_generator(a, subset_list=name_subset),
        batch_size=chunksize,
        return_func=list
    )

    pair_count = 0
    chunk_ix = 0
    for chunk_ix, list_of_pairs in enumerate(chunk_iter):
        pair_count += len([(t[0][0], t[0][1], t[1]) for t in
                           list_of_pairs if t is not None])

    # Were all pairs looped?
    assert pair_count == total_pairs_permute, \
        f'pair_count={pair_count} total_pairs={total_pairs_permute}'

    # Does the number of loop iterations correspond to the number of chunks
    # wanted?
    assert chunk_ix + 1 == chunks_wanted, \
        f'chunk_ix+1={chunk_ix + 1}, chunks_wanted={chunks_wanted}'