def _check_statement_distillation(num_stmts):
    """Sanity-check ``db_util.distill_stmts`` against a freshly loaded db.

    Loads a test database via ``get_pa_loaded_db(num_stmts)`` and asserts
    that:

    * distillation with ``get_full_stmts=True`` yields a non-empty collection
      of ``Statement`` objects;
    * plain distillation yields integer ids, as many as there are full
      statements;
    * repeating the plain distillation gives a result of the same size and
      equal to the first one.
    """
    db = get_pa_loaded_db(num_stmts)
    assert db is not None, "Test was broken. Got None instead of db insance."

    # Full Statement objects.
    full_stmts = db_util.distill_stmts(db, get_full_stmts=True)
    assert len(full_stmts), "Got zero statements."
    first_full = list(full_stmts)[0]
    assert isinstance(first_full, Statement), type(first_full)

    # Bare statement ids; must agree in size with the full result.
    id_set = db_util.distill_stmts(db)
    n_full, n_ids = len(full_stmts), len(id_set)
    assert n_full == n_ids, "stmts: %d, stmt_ids: %d" % (n_full, n_ids)
    first_id = list(id_set)[0]
    assert isinstance(first_id, int), type(first_id)

    # Repeat the id-only distillation and confirm the result is stable.
    # NOTE(review): these two calls are identical to the one above; if a
    # parallel (multi-process) variant was meant to be compared here,
    # confirm the keyword distill_stmts currently accepts for it.
    repeat_ids = db_util.distill_stmts(db)
    assert len(repeat_ids) == len(id_set)
    assert db_util.distill_stmts(db) == id_set
Exemple #2
0
def get_raw_stmts(tr_dicts, date_limit=None):
    """Fetch the distilled raw INDRA Statements for a set of TextRefs.

    Parameters
    ----------
    tr_dicts : dict of text ref information
        Maps TextRef IDs (ints) to dicts of text ref metadata. Every
        non-empty metadata dict is merged into the ``text_refs`` of the
        first evidence of each statement read from that TextRef (e.g. the
        aligned DB/CORD19 dictionaries from ``cord19_metadata_for_trs``).
    date_limit : Optional[int]
        If given (and truthy), only statements from readings created within
        this many days before the current UTC time are returned.

    Returns
    -------
    list of stmts
        Raw INDRA Statements retrieved from the INDRA DB; statements from
        the same TextRef are contiguous in the list.
    """
    db = get_primary_db()
    text_ref_ids = list(tr_dicts.keys())
    print(f"Distilling statements for {len(text_ref_ids)} TextRefs")
    t0 = time.time()

    # Join raw statements back to the TextRefs they were read from.
    clauses = [db.TextRef.id.in_(text_ref_ids),
               db.TextContent.text_ref_id == db.TextRef.id,
               db.Reading.text_content_id == db.TextContent.id,
               db.RawStatements.reading_id == db.Reading.id]
    if date_limit:
        cutoff = (datetime.datetime.utcnow()
                  - datetime.timedelta(days=date_limit))
        print(f'Limiting to stmts from readings in the last {date_limit} days')
        clauses.append(db.Reading.create_date > cutoff)
    db_stmts = distill_stmts(db, get_full_stmts=True, clauses=clauses)

    # Bucket statements by the TextRef ID they came from (first-seen order).
    by_trid = {}
    for stmt in db_stmts:
        trid = stmt.evidence[0].text_refs['TRID']
        by_trid.setdefault(trid, []).append(stmt)

    # Merge the caller-supplied metadata into each statement's evidence and
    # flatten the buckets back into a single list.
    result = []
    for trid, bucket in by_trid.items():
        extra = tr_dicts[trid]
        if extra:
            for stmt in bucket:
                stmt.evidence[0].text_refs.update(extra)
        result.extend(bucket)
    print(f"{time.time() - t0} seconds")
    return result
Exemple #3
0
def get_raw_statements_for_pmids(pmids, mode='all', batch_size=100):
    """Return raw INDRA Statements based on extractions from given PMIDs.

    Parameters
    ----------
    pmids : set or list of str
        A set of PMIDs to find raw INDRA Statements for in the INDRA DB.
    mode : 'all' or 'distilled'
        The 'distilled' mode makes sure that the "best", non-redundant
        set of raw statements are found across potentially redundant text
        contents and reader versions. The 'all' mode doesn't do such
        distillation but is significantly faster. Any value other than
        'distilled' is handled as 'all'.
    batch_size : Optional[int]
        Determines how many PMIDs to fetch statements for in each
        iteration. Default: 100.

    Returns
    -------
    dict
        A dict keyed by PMID whose values are lists of INDRA Statements
        obtained from that PMID.
    """
    db = get_db('primary')
    logger.info(f'Getting raw statements for {len(pmids)} PMIDs')
    all_stmts = defaultdict(list)
    # The progress bar total must be the number of batches as an integer;
    # len/batch_size would give a fractional (and rounded-down) total.
    n_batches = (len(pmids) + batch_size - 1) // batch_size
    pmid_batches = batch_iter(pmids, return_func=set, batch_size=batch_size)
    for pmid_batch in tqdm.tqdm(pmid_batches, total=n_batches):
        if mode == 'distilled':
            # Restrict distillation to raw statements read from this batch
            # of PMIDs.
            clauses = [
                db.TextRef.pmid.in_(pmid_batch),
                db.TextContent.text_ref_id == db.TextRef.id,
                db.Reading.text_content_id == db.TextContent.id,
                db.RawStatements.reading_id == db.Reading.id
            ]
            distilled_stmts = distill_stmts(db,
                                            get_full_stmts=True,
                                            clauses=clauses)
            for stmt in distilled_stmts:
                all_stmts[stmt.evidence[0].pmid].append(stmt)
        else:
            id_stmts = \
                get_raw_stmt_jsons_from_papers(pmid_batch, id_type='pmid',
                                               db=db)
            for pmid, stmt_jsons in id_stmts.items():
                all_stmts[pmid] += stmts_from_json(stmt_jsons)
    return dict(all_stmts)
    def supplement_corpus(self, db, continuing=False, dups_file=None):
        """Update the table of preassembled statements.

        This method will take any new raw statements that have not yet been
        incorporated into the preassembled table, and use them to augment the
        preassembled table.

        The resulting updated table is indistinguishable from the result you
        would achieve if you had simply re-run preassembly on _all_ the
        raw statements.

        Intermediate results are pickled to files in the current working
        directory so that a failed run can be resumed with
        ``continuing=True``; those files are deleted once the run succeeds.

        Parameters
        ----------
        db : database manager
            Handle to the INDRA DB that raw statements are read from and
            preassembled statements / support links are written to.
        continuing : bool
            If True, reuse any pickle stashes left behind by an earlier run.
        dups_file : Optional[str]
            Passed to ``distill_stmts`` as ``handle_duplicates``; ``'delete'``
            is used when it is not given.

        Returns
        -------
        bool
            True on success.
        """
        self.__tag = 'supplement'

        dup_handling = dups_file if dups_file else 'delete'

        # Every stash path recorded here is removed at the end of a
        # successful run.
        pickle_stashes = []
        last_update = self._get_latest_updatetime(db)
        start_date = datetime.utcnow()
        self._log("Latest update was: %s" % last_update)

        # Get the new statements...
        self._log("Loading info about the existing state of preassembly. "
                  "(This may take a little time)")
        new_id_stash = 'new_ids.pkl'
        pickle_stashes.append(new_id_stash)
        if continuing and path.exists(new_id_stash):
            self._log("Loading new statement ids from cache...")
            with open(new_id_stash, 'rb') as f:
                new_ids = pickle.load(f)
        else:
            new_ids = self._get_new_stmt_ids(db)

            # Stash the new ids in case we need to pick up where we left off.
            with open(new_id_stash, 'wb') as f:
                pickle.dump(new_ids, f)

        # Weed out exact duplicates.
        dist_stash = 'stmt_ids.pkl'
        pickle_stashes.append(dist_stash)
        if continuing and path.exists(dist_stash):
            self._log("Loading distilled statement ids from cache...")
            with open(dist_stash, 'rb') as f:
                stmt_ids = pickle.load(f)
        else:
            stmt_ids = distill_stmts(db,
                                     num_procs=self.n_proc,
                                     get_full_stmts=False,
                                     handle_duplicates=dup_handling)
            with open(dist_stash, 'wb') as f:
                pickle.dump(stmt_ids, f)

        # Only new raw statements that also survived distillation are used.
        new_stmt_ids = new_ids & stmt_ids

        # Get the set of new unique statements and link to any new evidence.
        old_mk_set = {mk for mk, in db.select_all(db.PAStatements.mk_hash)}
        self._log("Found %d old pa statements." % len(old_mk_set))
        new_mk_stash = 'new_mk_set.pkl'
        pickle_stashes.append(new_mk_stash)
        if continuing and path.exists(new_mk_stash):
            self._log("Loading hashes for new pa statements from cache...")
            with open(new_mk_stash, 'rb') as f:
                stash_dict = pickle.load(f)
            start_date = stash_dict['start']
            end_date = stash_dict['end']
            new_mk_set = stash_dict['mk_set']
        else:
            new_mk_set = self._get_unique_statements(db, new_stmt_ids,
                                                     len(new_stmt_ids),
                                                     old_mk_set)
            end_date = datetime.utcnow()
            with open(new_mk_stash, 'wb') as f:
                pickle.dump(
                    {
                        'start': start_date,
                        'end': end_date,
                        'mk_set': new_mk_set
                    }, f)
        if continuing:
            self._log("Original old mk set: %d" % len(old_mk_set))
            old_mk_set = old_mk_set - new_mk_set
            self._log("Adjusted old mk set: %d" % len(old_mk_set))

        self._log("Found %d new pa statements." % len(new_mk_set))

        # If we are continuing, check for support links that were already found.
        # The stash is a dict with 'existing links' and 'ids done' keys.
        support_link_stash = 'new_support_links.pkl'
        pickle_stashes.append(support_link_stash)
        if continuing and path.exists(support_link_stash):
            with open(support_link_stash, 'rb') as f:
                status_dict = pickle.load(f)
                existing_links = status_dict['existing links']
                # NOTE(review): npa_done is only carried forward into the
                # stash; it is never used to skip already-processed batches.
                npa_done = status_dict['ids done']
            self._log("Found %d existing links." % len(existing_links))
        else:
            existing_links = set()
            npa_done = set()

        # Now find the new support links that need to be added. The "new" pa
        # statements are those created between start_date and end_date.
        new_support_links = set()
        batching_args = (self.batch_size, db.PAStatements.json,
                         db.PAStatements.create_date >= start_date,
                         db.PAStatements.create_date <= end_date)
        npa_json_iter = db.select_all_batched(
            *batching_args, order_by=db.PAStatements.create_date)
        try:
            for outer_offset, npa_json_batch in npa_json_iter:
                npa_batch = [
                    _stmt_from_json(s_json) for s_json, in npa_json_batch
                ]

                # Compare internally
                self._log("Getting support for new pa at offset %d." %
                          outer_offset)
                some_support_links = self._get_support_links(npa_batch)

                # Compare against the other new batch statements.
                other_npa_json_iter = db.select_all_batched(
                    *batching_args,
                    order_by=db.PAStatements.create_date,
                    skip_offset=outer_offset)
                for inner_offset, other_npa_json_batch in other_npa_json_iter:
                    other_npa_batch = [
                        _stmt_from_json(s_json)
                        for s_json, in other_npa_json_batch
                    ]
                    split_idx = len(npa_batch)
                    full_list = npa_batch + other_npa_batch
                    self._log("Comparing offset %d to offset %d of other new "
                              "statements." % (outer_offset, inner_offset))
                    some_support_links |= \
                        self._get_support_links(full_list, split_idx=split_idx,
                                                poolsize=self.n_proc)

                # Compare against the existing statements.
                opa_json_iter = db.select_all_batched(
                    self.batch_size, db.PAStatements.json,
                    db.PAStatements.create_date < start_date)
                for old_offset, opa_json_batch in opa_json_iter:
                    opa_batch = [
                        _stmt_from_json(s_json) for s_json, in opa_json_batch
                    ]
                    split_idx = len(npa_batch)
                    full_list = npa_batch + opa_batch
                    self._log("Comparing new offset %d to offset %d of old "
                              "statements." % (outer_offset, old_offset))
                    some_support_links |= \
                        self._get_support_links(full_list, split_idx=split_idx,
                                                poolsize=self.n_proc)

                # Although there are generally few support links, copying as we
                # go allows work to not be wasted.
                new_support_links |= (some_support_links - existing_links)
                self._log("Copying batch of %d support links into db." %
                          len(new_support_links))
                db.copy('pa_support_links', new_support_links,
                        ('supported_mk_hash', 'supporting_mk_hash'))
                existing_links |= new_support_links
                npa_done |= {s.get_hash(shallow=True) for s in npa_batch}
                new_support_links = set()
                with open(support_link_stash, 'wb') as f:
                    pickle.dump(
                        {
                            'existing links': existing_links,
                            'ids done': npa_done
                        }, f)

            # Insert any remaining support links.
            if new_support_links:
                self._log("Copying batch final of %d support links into db." %
                          len(new_support_links))
                db.copy('pa_support_links', new_support_links,
                        ('supported_mk_hash', 'supporting_mk_hash'))
                existing_links |= new_support_links
        except Exception:
            logger.info("Stashing support links found so far.")
            if new_support_links:
                # Write the same dict layout as the per-batch stash above, so
                # that a continuing run can read it back. Only the links that
                # were already copied into the db (existing_links) are
                # recorded; the uncopied new_support_links are therefore not
                # filtered out, and will be copied, on the next run.
                with open(support_link_stash, 'wb') as f:
                    pickle.dump(
                        {
                            'existing links': existing_links,
                            'ids done': npa_done
                        }, f)
            raise

        # Remove all the caches so they can't be picked up accidentally later.
        for cache in pickle_stashes:
            if path.exists(cache):
                remove(cache)

        return True
    def create_corpus(self, db, continuing=False, dups_file=None):
        """Initialize the table of preassembled statements.

        This method will find the set of unique knowledge represented in the
        table of raw statements, and it will populate the table of preassembled
        statements (PAStatements/pa_statements), while maintaining links between
        the raw statements and their unique (pa) counterparts. Furthermore, the
        refinement/support relationships between unique statements will be found
        and recorded in the PASupportLinks/pa_support_links table.

        For more detail on preassembly, see indra/preassembler/__init__.py

        Parameters
        ----------
        db : database manager
            Handle to the INDRA DB to read from and write to.
        continuing : bool
            If True, reuse the cached distilled statement ids (if present) and
            skip raw statements and support links already recorded in the db.
        dups_file : Optional[str]
            Passed to ``distill_stmts`` as ``handle_duplicates``; ``'delete'``
            is used when it is not given.

        Returns
        -------
        bool
            True on success.
        """
        self.__tag = 'create'

        dup_handling = dups_file if dups_file else 'delete'

        # Get filtered statement ID's.
        sid_cache_fname = path.join(HERE, 'stmt_id_cache.pkl')
        if continuing and path.exists(sid_cache_fname):
            with open(sid_cache_fname, 'rb') as f:
                stmt_ids = pickle.load(f)
        else:
            # Get the statement ids.
            stmt_ids = distill_stmts(db,
                                     num_procs=self.n_proc,
                                     handle_duplicates=dup_handling)
            with open(sid_cache_fname, 'wb') as f:
                pickle.dump(stmt_ids, f)

        # Handle the possibility we're picking up after an earlier job...
        done_pa_ids = set()
        if continuing:
            self._log("Getting set of statements already de-duplicated...")
            link_resp = db.select_all([
                db.RawUniqueLinks.raw_stmt_id,
                db.RawUniqueLinks.pa_stmt_mk_hash
            ])
            if link_resp:
                # Reuse the rows fetched above instead of re-running the query.
                checked_raw_stmt_ids, pa_stmt_hashes = zip(*link_resp)
                stmt_ids -= set(checked_raw_stmt_ids)
                done_pa_ids = set(pa_stmt_hashes)
                self._log("Found %d preassembled statements already done." %
                          len(done_pa_ids))

        # Get the set of unique statements
        self._get_unique_statements(db, stmt_ids, len(stmt_ids), done_pa_ids)

        # If we are continuing, check for support links that were already found.
        if continuing:
            self._log("Getting pre-existing links...")
            # Column order must match the (supported, supporting) tuples that
            # are copied into pa_support_links below, or the set difference
            # against existing_links will never match anything.
            db_existing_links = db.select_all([
                db.PASupportLinks.supported_mk_hash,
                db.PASupportLinks.supporting_mk_hash
            ])
            existing_links = {tuple(res) for res in db_existing_links}
            self._log("Found %d existing links." % len(existing_links))
        else:
            existing_links = set()

        # Now get the support links between all batches.
        support_links = set()
        for i, outer_batch in enumerate(self._pa_batch_iter(db)):
            # Get internal support links
            self._log('Getting internal support links outer batch %d.' % i)
            some_support_links = self._get_support_links(outer_batch,
                                                         poolsize=self.n_proc)
            outer_mk_hashes = {shash(s) for s in outer_batch}

            # Get links with all other batches
            ib_iter = self._pa_batch_iter(db, ex_mks=outer_mk_hashes)
            for j, inner_batch in enumerate(ib_iter):
                split_idx = len(inner_batch)
                full_list = inner_batch + outer_batch
                self._log('Getting support compared to other batch %d of outer '
                          'batch %d.' % (j, i))
                some_support_links |= \
                    self._get_support_links(full_list, split_idx=split_idx,
                                            poolsize=self.n_proc)

            # Add all the new support links
            support_links |= (some_support_links - existing_links)

            # There are generally few support links compared to the number of
            # statements, so it doesn't make sense to copy every time, but for
            # long preassembly, this allows for better failure recovery.
            if len(support_links) >= self.batch_size:
                self._log("Copying batch of %d support links into db." %
                          len(support_links))
                db.copy('pa_support_links', support_links,
                        ('supported_mk_hash', 'supporting_mk_hash'))
                existing_links |= support_links
                support_links = set()

        # Insert any remaining support links.
        if support_links:
            self._log("Copying final batch of %d support links into db." %
                      len(support_links))
            db.copy('pa_support_links', support_links,
                    ('supported_mk_hash', 'supporting_mk_hash'))

        # Delete the pickle cache
        if path.exists(sid_cache_fname):
            remove(sid_cache_fname)

        return True
Exemple #6
0
    def _supplement_statements(self, db, continuing=False):
        """Supplement the preassembled statements with the latest content.

        Collects the raw statement ids added since the last update, keeps
        those that survive distillation and are not in DiscardedStatements,
        and hands them to ``_extract_and_push_unique_statements``. Each
        intermediate result is pickled to a file in the working directory
        (and its name recorded in ``self.pickle_stashes``), so a failed run
        can be resumed with ``continuing=True``.

        Returns
        -------
        tuple
            ``(start_date, end_date)``: UTC timestamps taken before and after
            the unique-statement extraction (restored from the stash when
            continuing).
        """
        self.__tag = 'supplement'

        last_update = self._get_latest_updatetime(db)
        assert last_update is not None, \
            "The preassembly tables have not yet been initialized."
        start_date = datetime.utcnow()
        self._log("Latest update was: %s" % last_update)

        def load_or_stash(stash_name, cache_msg, compute):
            # Register the stash for later cleanup, then either reload it
            # (when resuming) or compute the value and pickle it.
            self.pickle_stashes.append(stash_name)
            if continuing and path.exists(stash_name):
                self._log(cache_msg)
                with open(stash_name, 'rb') as fh:
                    return pickle.load(fh)
            value = compute()
            with open(stash_name, 'wb') as fh:
                pickle.dump(value, fh)
            return value

        # Raw statement ids added since the last update.
        self._log("Loading info about the existing state of preassembly. "
                  "(This may take a little time)")
        new_ids = load_or_stash('new_ids.pkl',
                                "Loading new statement ids from cache...",
                                lambda: self._get_new_stmt_ids(db))

        # Weed out exact duplicates.
        stmt_ids = load_or_stash(
            'stmt_ids.pkl',
            "Loading distilled statement ids from cache...",
            lambda: distill_stmts(db, get_full_stmts=False))

        # Keep only good new ids. `-` binds tighter than `&`; the parentheses
        # just make that existing precedence explicit.
        skip_ids = {i for i, in db.select_all(db.DiscardedStatements.stmt_id)}
        new_stmt_ids = new_ids & (stmt_ids - skip_ids)

        # Extract the new unique statements and link them to new evidence.
        old_mk_set = {mk for mk, in db.select_all(db.PAStatements.mk_hash)}
        self._log("Found %d old pa statements." % len(old_mk_set))

        def push_new_unique():
            # Push new unique statements, then stamp the end of the window.
            mk_set = self._extract_and_push_unique_statements(
                db, new_stmt_ids, len(new_stmt_ids), old_mk_set)
            return {'start': start_date,
                    'end': datetime.utcnow(),
                    'mk_set': mk_set}

        window = load_or_stash(
            'new_mk_set.pkl',
            "Loading hashes for new pa statements from cache...",
            push_new_unique)
        start_date = window['start']
        end_date = window['end']
        new_mk_set = window['mk_set']

        if continuing:
            self._log("Original old mk set: %d" % len(old_mk_set))
            old_mk_set = old_mk_set - new_mk_set
            self._log("Adjusted old mk set: %d" % len(old_mk_set))

        self._log("Found %d new pa statements." % len(new_mk_set))
        self.__tag = 'Unpurposed'
        return start_date, end_date
Exemple #7
0
    def create_corpus(self, db, continuing=False):
        """Initialize the table of preassembled statements.

        This method will find the set of unique knowledge represented in the
        table of raw statements, and it will populate the table of preassembled
        statements (PAStatements/pa_statements), while maintaining links between
        the raw statements and their unique (pa) counterparts. Furthermore, the
        refinement/support relationships between unique statements will be found
        and recorded in the PASupportLinks/pa_support_links table.

        For more detail on preassembly, see indra/preassembler/__init__.py

        Parameters
        ----------
        db : database manager
            Handle to the INDRA DB to read from and write to.
        continuing : bool
            If False, the DiscardedStatements table is dropped and recreated
            before starting. If True, the cached distilled ids (if present)
            are reused, and raw statements that are already linked or were
            discarded, as well as already-recorded support links, are skipped.

        Returns
        -------
        bool
            True on success.
        """
        self.__tag = 'create'

        if not continuing:
            # Make sure the discarded statements table is cleared.
            db.drop_tables([db.DiscardedStatements])
            db.create_tables([db.DiscardedStatements])
            db.session.close()
            db.grab_session()
        else:
            # Get discarded statements
            skip_ids = {
                i
                for i, in db.select_all(db.DiscardedStatements.stmt_id)
            }
            self._log("Found %d discarded statements from earlier run." %
                      len(skip_ids))

        # Get filtered statement ID's.
        sid_cache_fname = path.join(HERE, 'stmt_id_cache.pkl')
        if continuing and path.exists(sid_cache_fname):
            with open(sid_cache_fname, 'rb') as f:
                stmt_ids = pickle.load(f)
        else:
            # Get the statement ids.
            stmt_ids = distill_stmts(db)
            with open(sid_cache_fname, 'wb') as f:
                pickle.dump(stmt_ids, f)

        # Handle the possibility we're picking up after an earlier job...
        done_pa_ids = set()
        if continuing:
            self._log("Getting set of statements already de-duplicated...")
            link_resp = db.select_all([
                db.RawUniqueLinks.raw_stmt_id,
                db.RawUniqueLinks.pa_stmt_mk_hash
            ])
            if link_resp:
                # Reuse the rows fetched above instead of re-running the query.
                checked_raw_stmt_ids, pa_stmt_hashes = zip(*link_resp)
                stmt_ids -= set(checked_raw_stmt_ids)
                self._log("Found %d raw statements without links to unique." %
                          len(stmt_ids))
                done_pa_ids = set(pa_stmt_hashes)
                self._log("Found %d preassembled statements already done." %
                          len(done_pa_ids))
            # Discarded statements must be skipped even when the earlier run
            # stopped before writing any raw-unique links.
            stmt_ids -= skip_ids
            self._log("Found %d raw statements that still need to be "
                      "processed." % len(stmt_ids))

        # Get the set of unique statements
        self._extract_and_push_unique_statements(db, stmt_ids, len(stmt_ids),
                                                 done_pa_ids)

        # If we are continuing, check for support links that were already found
        if continuing:
            self._log("Getting pre-existing links...")
            # Column order must match the (supported, supporting) tuples that
            # are copied into pa_support_links below, or the set difference
            # against existing_links will never match anything.
            db_existing_links = db.select_all([
                db.PASupportLinks.supported_mk_hash,
                db.PASupportLinks.supporting_mk_hash
            ])
            existing_links = {tuple(res) for res in db_existing_links}
            self._log("Found %d existing links." % len(existing_links))
        else:
            existing_links = set()

        # Now get the support links between all batches.
        support_links = set()
        outer_iter = db.select_all_batched(self.batch_size,
                                           db.PAStatements.json,
                                           order_by=db.PAStatements.mk_hash)
        for outer_idx, outer_batch_jsons in outer_iter:
            outer_batch = [_stmt_from_json(sj) for sj, in outer_batch_jsons]
            # Get internal support links
            self._log('Getting internal support links outer batch %d.' %
                      outer_idx)
            some_support_links = self._get_support_links(outer_batch)

            # Get links with all other batches
            inner_iter = db.select_all_batched(
                self.batch_size,
                db.PAStatements.json,
                order_by=db.PAStatements.mk_hash,
                skip_idx=outer_idx)
            for inner_idx, inner_batch_jsons in inner_iter:
                inner_batch = [
                    _stmt_from_json(sj) for sj, in inner_batch_jsons
                ]
                split_idx = len(inner_batch)
                full_list = inner_batch + outer_batch
                self._log('Getting support between outer batch %d and inner '
                          'batch %d.' % (outer_idx, inner_idx))
                some_support_links |= \
                    self._get_support_links(full_list, split_idx=split_idx)

            # Add all the new support links
            support_links |= (some_support_links - existing_links)

            # There are generally few support links compared to the number of
            # statements, so it doesn't make sense to copy every time, but for
            # long preassembly, this allows for better failure recovery.
            if len(support_links) >= self.batch_size:
                self._log("Copying batch of %d support links into db." %
                          len(support_links))
                db.copy('pa_support_links', support_links,
                        ('supported_mk_hash', 'supporting_mk_hash'))
                gatherer.add('links', len(support_links))
                existing_links |= support_links
                support_links = set()

        # Insert any remaining support links.
        if support_links:
            self._log("Copying final batch of %d support links into db." %
                      len(support_links))
            db.copy('pa_support_links', support_links,
                    ('supported_mk_hash', 'supporting_mk_hash'))
            gatherer.add('links', len(support_links))

        # Delete the pickle cache
        if path.exists(sid_cache_fname):
            remove(sid_cache_fname)

        return True