Example #1
def test_split_idx():
    ras = Agent('RAS', db_refs={'FPLX': 'RAS'})
    kras = Agent('KRAS', db_refs={'HGNC': '6407'})
    hras = Agent('HRAS', db_refs={'HGNC': '5173'})
    st1 = Phosphorylation(Agent('x'), ras)
    st2 = Phosphorylation(Agent('x'), kras)
    st3 = Phosphorylation(Agent('x'), hras)
    pa = Preassembler(bio_ontology)
    maps = pa._generate_id_maps([st1, st2, st3])
    assert (1, 0) in maps, maps
    assert (2, 0) in maps, maps
    assert pa._comparison_counter == 2
    pa = Preassembler(bio_ontology)
    maps = pa._generate_id_maps([st1, st2, st3], split_idx=1)
    assert (2, 0) in maps, maps
    assert (1, 0) not in maps, maps
    assert pa._comparison_counter == 1

    # Test the other relation-generating methods
    refinements = pa._generate_relations([st1, st2, st3])
    assert refinements == \
        {st2.get_hash(): {st1.get_hash()},
         st3.get_hash(): {st1.get_hash()}}, refinements

    refinements = pa._generate_relation_tuples([st1, st2, st3])
    assert refinements == \
        {(st2.get_hash(), st1.get_hash()),
         (st3.get_hash(), st1.get_hash())}
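
# A minimal illustrative sketch (not from the INDRA codebase): the index
# pairs returned by _generate_id_maps can be turned into pairs of statement
# hashes, in the same order, mirroring the _get_support_links helpers shown
# in the examples below. The function name here is hypothetical.
def idx_pairs_to_hash_pairs(pa, unique_stmts, **kwargs):
    id_maps = pa._generate_id_maps(unique_stmts, **kwargs)
    return {tuple(unique_stmts[ix].get_hash() for ix in pair)
            for pair in id_maps}
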
Example #2
def test_split_idx():
    ras = Agent('RAS', db_refs={'FPLX': 'RAS'})
    kras = Agent('KRAS', db_refs={'HGNC': '6407'})
    hras = Agent('HRAS', db_refs={'HGNC': '5173'})
    st1 = Phosphorylation(Agent('x'), ras)
    st2 = Phosphorylation(Agent('x'), kras)
    st3 = Phosphorylation(Agent('x'), hras)
    pa = Preassembler(bio_ontology)
    maps = pa._generate_id_maps([st1, st2, st3])
    assert (1, 0) in maps, maps
    assert (2, 0) in maps, maps
    assert pa._comparison_counter == 2
    pa = Preassembler(bio_ontology)
    maps = pa._generate_id_maps([st1, st2, st3], split_idx=1)
    assert (2, 0) in maps, maps
    assert (1, 0) not in maps, maps
    assert pa._comparison_counter == 1
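
# Note on split_idx, judging from the assertions above: when split_idx is
# given, _generate_id_maps only compares pairs of statements whose indices
# fall on opposite sides of the split (one index <= split_idx, the other
# > split_idx). With split_idx=1, st1 (index 0) and st2 (index 1) are on the
# same side, so the (1, 0) refinement is skipped, only the st3-vs-st1
# comparison is made, and the comparison counter drops from 2 to 1.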
Example #3
class PreassemblyManager(object):
    """Class used to manage the preassembly pipeline

    Parameters
    ----------
    n_proc : int
        Select the number of processes that will be used when performing
        preassembly. Default is 1.
    batch_size : int
        Select the maximum number of statements you wish to be handled at a
        time. In general, a larger batch size will be somewhat faster, but
        will require much more memory.
    """
    def __init__(self, n_proc=1, batch_size=10000):
        self.n_proc = n_proc
        self.batch_size = batch_size
        self.pa = Preassembler(hierarchies)
        self.__tag = 'Unpurposed'
        return

    def _get_latest_updatetime(self, db):
        """Get the date of the latest update."""
        update_list = db.select_all(db.PreassemblyUpdates)
        if not len(update_list):
            logger.warning("The preassembled corpus has not been initialized, "
                           "or else the updates table has not been populated.")
            return None
        return max([u.run_datetime for u in update_list])

    def _pa_batch_iter(self, db, in_mks=None, ex_mks=None):
        """Return an iterator over batches of preassembled statements.

        This avoids the need to load all such statements from the database into
        RAM at the same time (as this can be quite large).

        You may limit the set of pa_statements loaded by providing a set/list
        of matches-key hashes to include (in_mks) or exclude (ex_mks).
        """
        if in_mks is None and ex_mks is None:
            db_stmt_iter = db.select_all(db.PAStatements.json,
                                         yield_per=self.batch_size)
        elif ex_mks is None and in_mks:
            db_stmt_iter = db.select_all(db.PAStatements.json,
                                         db.PAStatements.mk_hash.in_(in_mks),
                                         yield_per=self.batch_size)
        elif in_mks is None and ex_mks:
            db_stmt_iter = db.select_all(
                db.PAStatements.json,
                db.PAStatements.mk_hash.notin_(ex_mks),
                yield_per=self.batch_size)
        elif in_mks and ex_mks:
            db_stmt_iter = db.select_all(
                db.PAStatements.json,
                db.PAStatements.mk_hash.notin_(ex_mks),
                db.PAStatements.mk_hash.in_(in_mks),
                yield_per=self.batch_size)
        else:
            db_stmt_iter = db.select_all(db.PAStatements.json,
                                         yield_per=self.batch_size)

        pa_stmts = (_stmt_from_json(s_json) for s_json, in db_stmt_iter)
        return batch_iter(pa_stmts, self.batch_size, return_func=list)

    def _get_unique_statements(self, db, stmt_tpls, num_stmts, mk_done=None):
        """Get the unique Statements from the raw statements."""
        if mk_done is None:
            mk_done = set()

        new_mk_set = set()
        stmt_batches = batch_iter(stmt_tpls, self.batch_size, return_func=list)
        num_batches = num_stmts / self.batch_size
        for i, stmt_tpl_batch in enumerate(stmt_batches):
            self._log("Processing batch %d/%d of %d/%d statements." %
                      (i, num_batches, len(stmt_tpl_batch), num_stmts))
            unique_stmts, evidence_links = \
                self._make_unique_statement_set(stmt_tpl_batch)
            new_unique_stmts = []
            for s in unique_stmts:
                s_hash = s.get_hash(shallow=True)
                if s_hash not in (mk_done | new_mk_set):
                    new_mk_set.add(s_hash)
                    new_unique_stmts.append(s)
            insert_pa_stmts(db, new_unique_stmts)
            db.copy('raw_unique_links', evidence_links,
                    ('pa_stmt_mk_hash', 'raw_stmt_id'))
        self._log("Added %d new pa statements into the database." %
                  len(new_mk_set))
        return new_mk_set

    @_handle_update_table
    def create_corpus(self, db, continuing=False):
        """Initialize the table of preassembled statements.

        This method will find the set of unique knowledge represented in the
        table of raw statements, and it will populate the table of preassembled
        statements (PAStatements/pa_statements), while maintaining links between
        the raw statements and their unique (pa) counterparts. Furthermore, the
        refinement/support relationships between unique statements will be found
        and recorded in the PASupportLinks/pa_support_links table.

        For more detail on preassembly, see indra/preassembler/__init__.py
        """
        self.__tag = 'create'
        # Get the statements
        stmt_ids = distill_stmts(db, num_procs=self.n_proc)
        if continuing:
            self._log("Getting set of statements already de-duplicated...")
            checked_raw_stmt_ids, pa_stmt_hashes = \
                zip(*db.select_all([db.RawUniqueLinks.raw_stmt_id,
                                    db.RawUniqueLinks.pa_stmt_mk_hash]))
            stmt_ids -= set(checked_raw_stmt_ids)
            done_pa_ids = set(pa_stmt_hashes)
            self._log("Found %d preassembled statements already done." %
                      len(done_pa_ids))
        else:
            done_pa_ids = set()
        stmts = ((sid, _stmt_from_json(s_json)) for sid, s_json in
                 db.select_all([db.RawStatements.id, db.RawStatements.json],
                               db.RawStatements.id.in_(stmt_ids),
                               yield_per=self.batch_size))
        self._log("Found %d statements in all." % len(stmt_ids))

        # Get the set of unique statements
        if stmt_ids:
            self._get_unique_statements(db, stmts, len(stmt_ids), done_pa_ids)

        # If we are continuing, check for support links that were already found.
        if continuing:
            self._log("Getting pre-existing links...")
            db_existing_links = db.select_all([
                db.PASupportLinks.supported_mk_hash,
                db.PASupportLinks.supporting_mk_hash
            ])
            existing_links = {tuple(res) for res in db_existing_links}
            self._log("Found %d existing links." % len(existing_links))
        else:
            existing_links = set()

        # Now get the support links between all batches.
        support_links = set()
        for outer_batch in self._pa_batch_iter(db):
            # Get internal support links
            some_support_links = self._get_support_links(outer_batch,
                                                         poolsize=self.n_proc)
            outer_mk_hashes = {s.get_hash(shallow=True) for s in outer_batch}

            # Get links with all other batches
            for inner_batch in self._pa_batch_iter(db, ex_mks=outer_mk_hashes):
                split_idx = len(inner_batch)
                full_list = inner_batch + outer_batch
                some_support_links |= \
                    self._get_support_links(full_list, split_idx=split_idx,
                                            poolsize=self.n_proc)

            # Add all the new support links
            support_links |= (some_support_links - existing_links)

            # There are generally few support links compared to the number of
            # statements, so it doesn't make sense to copy every time, but for
            # long preassembly, this allows for better failure recovery.
            if len(support_links) >= self.batch_size:
                self._log("Copying batch of %d support links into db." %
                          len(support_links))
                db.copy('pa_support_links', support_links,
                        ('supported_mk_hash', 'supporting_mk_hash'))
                existing_links |= support_links
                support_links = set()

        # Insert any remaining support links.
        if support_links:
            self._log("Copying final batch of %d support links into db." %
                      len(support_links))
            db.copy('pa_support_links', support_links,
                    ('supported_mk_hash', 'supporting_mk_hash'))

        return True
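
    # Note on the batching pattern above: each outer batch is first compared
    # against itself to find internal support links, and is then concatenated
    # with every other batch (fetched with ex_mks excluding the outer hashes)
    # and compared using split_idx so that comparisons are restricted to pairs
    # that span the two batches. Together this covers all pairs of
    # preassembled statements without holding more than two batches in memory
    # at once.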

    def _get_new_stmt_ids(self, db):
        """Get all the uuids of statements not included in evidence."""
        old_id_q = db.filter_query(
            db.RawStatements.id,
            db.RawStatements.id == db.RawUniqueLinks.raw_stmt_id)
        new_sid_q = db.filter_query(db.RawStatements.id).except_(old_id_q)
        all_new_stmt_ids = {sid for sid, in new_sid_q.all()}
        self._log("Found %d new statement ids." % len(all_new_stmt_ids))
        return all_new_stmt_ids

    @_handle_update_table
    def supplement_corpus(self, db, continuing=False):
        """Update the table of preassembled statements.

        This method will take any new raw statements that have not yet been
        incorporated into the preassembled table, and use them to augment the
        preassembled table.

        The resulting updated table is indistinguishable from the result you
        would achieve if you had simply re-run preassembly on _all_ the
        raw statements.
        """
        self.__tag = 'supplement'
        last_update = self._get_latest_updatetime(db)
        self._log("Latest update was: %s" % last_update)

        # Get the new statements...
        self._log("Loading info about the existing state of preassembly. "
                  "(This may take a little time)")
        new_ids = self._get_new_stmt_ids(db)

        # If we are continuing, check for support links that were already found.
        if continuing:
            self._log("Getting pre-existing links...")
            db_existing_links = db.select_all([
                db.PASupportLinks.supported_mk_hash,
                db.PASupportLinks.supporting_mk_hash
            ])
            existing_links = {tuple(res) for res in db_existing_links}
            self._log("Found %d existing links." % len(existing_links))
        else:
            existing_links = set()

        # Weed out exact duplicates.
        stmt_ids = distill_stmts(db, num_procs=self.n_proc)
        new_stmt_ids = new_ids & stmt_ids
        self._log("There are %d new distilled raw statement ids." %
                  len(new_stmt_ids))
        new_stmts = ((sid, _stmt_from_json(s_json))
                     for sid, s_json in db.select_all(
                         [db.RawStatements.id, db.RawStatements.json],
                         db.RawStatements.id.in_(new_stmt_ids),
                         yield_per=self.batch_size))

        # Get the set of new unique statements and link to any new evidence.
        old_mk_set = {mk for mk, in db.select_all(db.PAStatements.mk_hash)}
        self._log("Found %d old pa statements." % len(old_mk_set))
        new_mk_set = self._get_unique_statements(db, new_stmts,
                                                 len(new_stmt_ids), old_mk_set)
        self._log("Found %d new pa statements." % len(new_mk_set))

        # Now find the new support links that need to be added.
        new_support_links = set()
        for npa_batch in self._pa_batch_iter(db, in_mks=new_mk_set):
            some_support_links = set()

            # Compare internally
            some_support_links |= self._get_support_links(npa_batch)

            # Compare against the other new batch statements.
            diff_new_mks = new_mk_set - {
                s.get_hash(shallow=True)
                for s in npa_batch
            }
            for diff_npa_batch in self._pa_batch_iter(db, in_mks=diff_new_mks):
                split_idx = len(npa_batch)
                full_list = npa_batch + diff_npa_batch
                some_support_links |= \
                    self._get_support_links(full_list, split_idx=split_idx,
                                            poolsize=self.n_proc)

            # Compare against the existing statements.
            for opa_batch in self._pa_batch_iter(db, in_mks=old_mk_set):
                split_idx = len(npa_batch)
                full_list = npa_batch + opa_batch
                some_support_links |= \
                    self._get_support_links(full_list, split_idx=split_idx,
                                            poolsize=self.n_proc)

            new_support_links |= (some_support_links - existing_links)

            # There are generally few support links compared to the number of
            # statements, so it doesn't make sense to copy every time, but for
            # long preassembly, this allows for better failure recovery.
            if len(new_support_links) >= self.batch_size:
                self._log("Copying batch of %d support links into db." %
                          len(new_support_links))
                db.copy('pa_support_links', new_support_links,
                        ('supported_mk_hash', 'supporting_mk_hash'))
                existing_links |= new_support_links
                new_support_links = set()

        # Insert any remaining support links.
        if new_support_links:
            self._log("Copying batch final of %d support links into db." %
                      len(new_support_links))
            db.copy('pa_support_links', new_support_links,
                    ('supported_mk_hash', 'supporting_mk_hash'))
            existing_links |= new_support_links

        return True

    def _log(self, msg, level='info'):
        """Applies a task specific tag to the log message."""
        getattr(logger, level)("(%s) %s" % (self.__tag, msg))

    def _make_unique_statement_set(self, stmt_tpls):
        """Perform grounding, sequence mapping, and find unique set from stmts.

        This method returns a list of unique statement objects, as well as a
        set of tuples of the form (pa_stmt_mk_hash, raw_stmt_id) which
        represent the links between raw (evidence) statements and their
        unique/preassembled counterparts.
        """
        stmts = []
        uuid_sid_dict = {}
        for sid, stmt in stmt_tpls:
            uuid_sid_dict[stmt.uuid] = sid
            stmts.append(stmt)
        stmts = ac.map_grounding(stmts)
        stmts = ac.map_sequence(stmts)
        stmt_groups = self.pa._get_stmt_matching_groups(stmts)
        unique_stmts = []
        evidence_links = defaultdict(lambda: set())
        for _, duplicates in stmt_groups:
            # Get the first statement and add the evidence of all subsequent
            # Statements to it
            for stmt_ix, stmt in enumerate(duplicates):
                if stmt_ix == 0:
                    first_stmt = stmt.make_generic_copy()
                    stmt_hash = first_stmt.get_hash(shallow=True)
                evidence_links[stmt_hash].add(uuid_sid_dict[stmt.uuid])
            # first_stmt is always set on the first pass through the loop above
            assert isinstance(first_stmt, type(stmt))
            unique_stmts.append(first_stmt)
        return unique_stmts, flatten_evidence_dict(evidence_links)

    def _get_support_links(self, unique_stmts, **generate_id_map_kwargs):
        """Find the links of refinement/support between statements."""
        id_maps = self.pa._generate_id_maps(unique_stmts,
                                            **generate_id_map_kwargs)
        return {
            tuple(
                [unique_stmts[idx].get_hash(shallow=True) for idx in idx_pair])
            for idx_pair in id_maps
        }
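
# The helpers above rely on a batch_iter utility with the call signature
# batch_iter(iterable, batch_size, return_func=list). A minimal sketch of
# such a helper is given below for reference; the actual INDRA utility may be
# implemented differently.
from itertools import islice


def batch_iter(iterator, batch_size, return_func=None):
    """Yield successive batches of up to batch_size items from iterator."""
    iterator = iter(iterator)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield return_func(batch) if return_func is not None else batch
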
Example #4
class PreassemblyManager(object):
    """Class used to manage the preassembly pipeline

    Parameters
    ----------
    n_proc : int
        Select the number of processes that will be used when performing
        preassembly. Default is 1.
    batch_size : int
        Select the maximum number of statements you wish to be handled at a
        time. In general, a larger batch size will be somewhat faster, but
        will require much more memory.
    """
    def __init__(self, n_proc=1, batch_size=10000, print_logs=False):
        self.n_proc = n_proc
        self.batch_size = batch_size
        self.pa = Preassembler(hierarchies)
        self.__tag = 'Unpurposed'
        self.__print_logs = print_logs
        return

    def _get_latest_updatetime(self, db):
        """Get the date of the latest update."""
        update_list = db.select_all(db.PreassemblyUpdates)
        if not len(update_list):
            logger.warning("The preassembled corpus has not been initialized, "
                           "or else the updates table has not been populated.")
            return None
        return max([u.run_datetime for u in update_list])

    def _pa_batch_iter(self, db, in_mks=None, ex_mks=None):
        """Return an iterator over batches of preassembled statements.

        This avoids the need to load all such statements from the database into
        RAM at the same time (as this can be quite large).

        You may limit the set of pa_statements loaded by providing a set/list
        of matches-key hashes to include (in_mks) or exclude (ex_mks).
        """
        if in_mks is None and ex_mks is None:
            db_stmt_iter = db.select_all(db.PAStatements.json,
                                         yield_per=self.batch_size)
        elif ex_mks is None and in_mks is not None:
            if not in_mks:
                return []
            db_stmt_iter = db.select_all(db.PAStatements.json,
                                         db.PAStatements.mk_hash.in_(in_mks),
                                         yield_per=self.batch_size)
        elif in_mks is None and ex_mks is not None:
            if not ex_mks:
                return []
            db_stmt_iter = db.select_all(
                db.PAStatements.json,
                db.PAStatements.mk_hash.notin_(ex_mks),
                yield_per=self.batch_size)
        elif in_mks and ex_mks:
            db_stmt_iter = db.select_all(
                db.PAStatements.json,
                db.PAStatements.mk_hash.notin_(ex_mks),
                db.PAStatements.mk_hash.in_(in_mks),
                yield_per=self.batch_size)
        else:  # Neither is None, and at least one is empty.
            return []

        pa_stmts = (_stmt_from_json(s_json) for s_json, in db_stmt_iter)
        return batch_iter(pa_stmts, self.batch_size, return_func=list)
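
    # Unlike the earlier version of this method, an explicitly provided but
    # empty in_mks or ex_mks set now short-circuits to an empty result rather
    # than falling through to a query over all preassembled statements.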

    def _raw_sid_stmt_iter(self, db, id_set, do_enumerate=False):
        """Return a generator over statements with the given database ids."""
        i = 0
        for stmt_id_batch in batch_iter(id_set, self.batch_size):
            subres = db.select_all(
                [db.RawStatements.id, db.RawStatements.json],
                db.RawStatements.id.in_(stmt_id_batch),
                yield_per=self.batch_size // 10)
            if do_enumerate:
                yield i, [(sid, _stmt_from_json(s_json))
                          for sid, s_json in subres]
                i += 1
            else:
                yield [(sid, _stmt_from_json(s_json))
                       for sid, s_json in subres]

    @clockit
    def _get_unique_statements(self, db, raw_sids, num_stmts, mk_done=None):
        """Get the unique Statements from the raw statements."""
        self._log("There are %d distilled raw statement ids to preassemble." %
                  len(raw_sids))

        if mk_done is None:
            mk_done = set()

        new_mk_set = set()
        num_batches = num_stmts / self.batch_size
        for i, stmt_tpl_batch in self._raw_sid_stmt_iter(db, raw_sids, True):
            self._log("Processing batch %d/%d of %d/%d statements." %
                      (i, num_batches, len(stmt_tpl_batch), num_stmts))
            # Get a list of statements, and generate a mapping from uuid to sid.
            stmts = []
            uuid_sid_dict = {}
            for sid, stmt in stmt_tpl_batch:
                uuid_sid_dict[stmt.uuid] = sid
                stmts.append(stmt)

            # Map groundings and sequences.
            cleaned_stmts = self._clean_statements(stmts)

            # Use the shallow hash to condense unique statements.
            new_unique_stmts, evidence_links = \
                self._condense_statements(cleaned_stmts, mk_done, new_mk_set,
                                          uuid_sid_dict)

            self._log("Insert new statements into database...")
            insert_pa_stmts(db, new_unique_stmts)
            self._log("Insert new raw_unique links into the database...")
            db.copy('raw_unique_links', flatten_evidence_dict(evidence_links),
                    ('pa_stmt_mk_hash', 'raw_stmt_id'))
        self._log("Added %d new pa statements into the database." %
                  len(new_mk_set))
        return new_mk_set

    @clockit
    def _condense_statements(self, cleaned_stmts, mk_done, new_mk_set,
                             uuid_sid_dict):
        self._log("Condense into unique statements...")
        new_unique_stmts = []
        evidence_links = defaultdict(lambda: set())
        for s in cleaned_stmts:
            h = shash(s)

            # If this statement is new, make it.
            if h not in mk_done and h not in new_mk_set:
                new_unique_stmts.append(s.make_generic_copy())
                new_mk_set.add(h)

            # Add the evidence to the dict.
            evidence_links[h].add(uuid_sid_dict[s.uuid])
        return new_unique_stmts, evidence_links
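
    # Note on the condensation above: the first statement seen for each
    # shallow matches-key hash is turned into a generic (evidence-free) copy
    # that becomes the preassembled statement, while the raw statement ids of
    # all statements sharing that hash are accumulated in evidence_links,
    # keyed by the hash, to later populate the raw_unique_links table.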

    @_handle_update_table
    def create_corpus(self, db, continuing=False, dups_file=None):
        """Initialize the table of preassembled statements.

        This method will find the set of unique knowledge represented in the
        table of raw statements, and it will populate the table of preassembled
        statements (PAStatements/pa_statements), while maintaining links between
        the raw statements and their unique (pa) counterparts. Furthermore, the
        refinement/support relationships between unique statements will be found
        and recorded in the PASupportLinks/pa_support_links table.

        For more detail on preassembly, see indra/preassembler/__init__.py
        """
        self.__tag = 'create'

        dup_handling = dups_file if dups_file else 'delete'

        # Get filtered statement ID's.
        sid_cache_fname = path.join(HERE, 'stmt_id_cache.pkl')
        if continuing and path.exists(sid_cache_fname):
            with open(sid_cache_fname, 'rb') as f:
                stmt_ids = pickle.load(f)
        else:
            # Get the statement ids.
            stmt_ids = distill_stmts(db,
                                     num_procs=self.n_proc,
                                     handle_duplicates=dup_handling)
            with open(sid_cache_fname, 'wb') as f:
                pickle.dump(stmt_ids, f)

        # Handle the possibility we're picking up after an earlier job...
        done_pa_ids = set()
        if continuing:
            self._log("Getting set of statements already de-duplicated...")
            link_resp = db.select_all([
                db.RawUniqueLinks.raw_stmt_id,
                db.RawUniqueLinks.pa_stmt_mk_hash
            ])
            if link_resp:
                checked_raw_stmt_ids, pa_stmt_hashes = \
                    zip(*db.select_all([db.RawUniqueLinks.raw_stmt_id,
                                        db.RawUniqueLinks.pa_stmt_mk_hash]))
                stmt_ids -= set(checked_raw_stmt_ids)
                done_pa_ids = set(pa_stmt_hashes)
                self._log("Found %d preassembled statements already done." %
                          len(done_pa_ids))

        # Get the set of unique statements
        self._get_unique_statements(db, stmt_ids, len(stmt_ids), done_pa_ids)

        # If we are continuing, check for support links that were already found.
        if continuing:
            self._log("Getting pre-existing links...")
            db_existing_links = db.select_all([
                db.PASupportLinks.supported_mk_hash,
                db.PASupportLinks.supporting_mk_hash
            ])
            existing_links = {tuple(res) for res in db_existing_links}
            self._log("Found %d existing links." % len(existing_links))
        else:
            existing_links = set()

        # Now get the support links between all batches.
        support_links = set()
        for i, outer_batch in enumerate(self._pa_batch_iter(db)):
            # Get internal support links
            self._log('Getting internal support links outer batch %d.' % i)
            some_support_links = self._get_support_links(outer_batch,
                                                         poolsize=self.n_proc)
            outer_mk_hashes = {shash(s) for s in outer_batch}

            # Get links with all other batches
            ib_iter = self._pa_batch_iter(db, ex_mks=outer_mk_hashes)
            for j, inner_batch in enumerate(ib_iter):
                split_idx = len(inner_batch)
                full_list = inner_batch + outer_batch
                self._log('Getting support compared to other batch %d of '
                          'outer batch %d.' % (j, i))
                some_support_links |= \
                    self._get_support_links(full_list, split_idx=split_idx,
                                            poolsize=self.n_proc)

            # Add all the new support links
            support_links |= (some_support_links - existing_links)

            # There are generally few support links compared to the number of
            # statements, so it doesn't make sense to copy every time, but for
            # long preassembly, this allows for better failure recovery.
            if len(support_links) >= self.batch_size:
                self._log("Copying batch of %d support links into db." %
                          len(support_links))
                db.copy('pa_support_links', support_links,
                        ('supported_mk_hash', 'supporting_mk_hash'))
                existing_links |= support_links
                support_links = set()

        # Insert any remaining support links.
        if support_links:
            self._log("Copying final batch of %d support links into db." %
                      len(support_links))
            db.copy('pa_support_links', support_links,
                    ('supported_mk_hash', 'supporting_mk_hash'))

        # Delete the pickle cache
        if path.exists(sid_cache_fname):
            remove(sid_cache_fname)

        return True

    def _get_new_stmt_ids(self, db):
        """Get all the uuids of statements not included in evidence."""
        old_id_q = db.filter_query(
            db.RawStatements.id,
            db.RawStatements.id == db.RawUniqueLinks.raw_stmt_id)
        new_sid_q = db.filter_query(db.RawStatements.id).except_(old_id_q)
        all_new_stmt_ids = {sid for sid, in new_sid_q.all()}
        self._log("Found %d new statement ids." % len(all_new_stmt_ids))
        return all_new_stmt_ids

    @_handle_update_table
    def supplement_corpus(self, db, continuing=False, dups_file=None):
        """Update the table of preassembled statements.

        This method will take any new raw statements that have not yet been
        incorporated into the preassembled table, and use them to augment the
        preassembled table.

        The resulting updated table is indistinguishable from the result you
        would achieve if you had simply re-run preassembly on _all_ the
        raw statements.
        """
        self.__tag = 'supplement'

        dup_handling = dups_file if dups_file else 'delete'

        pickle_stashes = []
        last_update = self._get_latest_updatetime(db)
        start_date = datetime.utcnow()
        self._log("Latest update was: %s" % last_update)

        # Get the new statements...
        self._log("Loading info about the existing state of preassembly. "
                  "(This may take a little time)")
        new_id_stash = 'new_ids.pkl'
        pickle_stashes.append(new_id_stash)
        if continuing and path.exists(new_id_stash):
            self._log("Loading new statement ids from cache...")
            with open(new_id_stash, 'rb') as f:
                new_ids = pickle.load(f)
        else:
            new_ids = self._get_new_stmt_ids(db)

            # Stash the new ids in case we need to pick up where we left off.
            with open(new_id_stash, 'wb') as f:
                pickle.dump(new_ids, f)

        # Weed out exact duplicates.
        dist_stash = 'stmt_ids.pkl'
        pickle_stashes.append(dist_stash)
        if continuing and path.exists(dist_stash):
            self._log("Loading distilled statement ids from cache...")
            with open(dist_stash, 'rb') as f:
                stmt_ids = pickle.load(f)
        else:
            stmt_ids = distill_stmts(db,
                                     num_procs=self.n_proc,
                                     get_full_stmts=False,
                                     handle_duplicates=dup_handling)
            with open(dist_stash, 'wb') as f:
                pickle.dump(stmt_ids, f)

        new_stmt_ids = new_ids & stmt_ids

        # Get the set of new unique statements and link to any new evidence.
        old_mk_set = {mk for mk, in db.select_all(db.PAStatements.mk_hash)}
        self._log("Found %d old pa statements." % len(old_mk_set))
        new_mk_stash = 'new_mk_set.pkl'
        pickle_stashes.append(new_mk_stash)
        if continuing and path.exists(new_mk_stash):
            self._log("Loading hashes for new pa statements from cache...")
            with open(new_mk_stash, 'rb') as f:
                stash_dict = pickle.load(f)
            start_date = stash_dict['start']
            end_date = stash_dict['end']
            new_mk_set = stash_dict['mk_set']
        else:
            new_mk_set = self._get_unique_statements(db, new_stmt_ids,
                                                     len(new_stmt_ids),
                                                     old_mk_set)
            end_date = datetime.utcnow()
            with open(new_mk_stash, 'wb') as f:
                pickle.dump(
                    {
                        'start': start_date,
                        'end': end_date,
                        'mk_set': new_mk_set
                    }, f)
        if continuing:
            self._log("Original old mk set: %d" % len(old_mk_set))
            old_mk_set = old_mk_set - new_mk_set
            self._log("Adjusted old mk set: %d" % len(old_mk_set))

        self._log("Found %d new pa statements." % len(new_mk_set))

        # If we are continuing, check for support links that were already found.
        support_link_stash = 'new_support_links.pkl'
        pickle_stashes.append(support_link_stash)
        if continuing and path.exists(support_link_stash):
            with open(support_link_stash, 'rb') as f:
                status_dict = pickle.load(f)
                existing_links = status_dict['existing links']
                npa_done = status_dict['ids done']
            self._log("Found %d existing links." % len(existing_links))
        else:
            existing_links = set()
            npa_done = set()

        # Now find the new support links that need to be added.
        new_support_links = set()
        batching_args = (self.batch_size, db.PAStatements.json,
                         db.PAStatements.create_date >= start_date,
                         db.PAStatements.create_date <= end_date)
        npa_json_iter = db.select_all_batched(
            *batching_args, order_by=db.PAStatements.create_date)
        try:
            for outer_offset, npa_json_batch in npa_json_iter:
                npa_batch = [
                    _stmt_from_json(s_json) for s_json, in npa_json_batch
                ]

                # Compare internally
                self._log("Getting support for new pa at offset %d." %
                          outer_offset)
                some_support_links = self._get_support_links(npa_batch)

                # Compare against the other new batch statements.
                diff_new_mks = new_mk_set - {shash(s) for s in npa_batch}
                other_npa_json_iter = db.select_all_batched(
                    *batching_args,
                    order_by=db.PAStatements.create_date,
                    skip_offset=outer_offset)
                for inner_offset, other_npa_json_batch in other_npa_json_iter:
                    other_npa_batch = [
                        _stmt_from_json(s_json)
                        for s_json, in other_npa_json_batch
                    ]
                    split_idx = len(npa_batch)
                    full_list = npa_batch + other_npa_batch
                    self._log("Comparing offset %d to offset %d of other new "
                              "statements." % (outer_offset, inner_offset))
                    some_support_links |= \
                        self._get_support_links(full_list, split_idx=split_idx,
                                                poolsize=self.n_proc)

                # Compare against the existing statements.
                opa_json_iter = db.select_all_batched(
                    self.batch_size, db.PAStatements.json,
                    db.PAStatements.create_date < start_date)
                for old_offset, opa_json_batch in opa_json_iter:
                    opa_batch = [
                        _stmt_from_json(s_json) for s_json, in opa_json_batch
                    ]
                    split_idx = len(npa_batch)
                    full_list = npa_batch + opa_batch
                    self._log("Comparing new offset %d to offset %d of old "
                              "statements." % (outer_offset, old_offset))
                    some_support_links |= \
                        self._get_support_links(full_list, split_idx=split_idx,
                                                poolsize=self.n_proc)

                # Although there are generally few support links, copying as we
                # go allows work to not be wasted.
                new_support_links |= (some_support_links - existing_links)
                self._log("Copying batch of %d support links into db." %
                          len(new_support_links))
                db.copy('pa_support_links', new_support_links,
                        ('supported_mk_hash', 'supporting_mk_hash'))
                existing_links |= new_support_links
                npa_done |= {s.get_hash(shallow=True) for s in npa_batch}
                new_support_links = set()
                with open(support_link_stash, 'wb') as f:
                    pickle.dump(
                        {
                            'existing links': existing_links,
                            'ids done': npa_done
                        }, f)

            # Insert any remaining support links.
            if new_support_links:
                self._log("Copying batch final of %d support links into db." %
                          len(new_support_links))
                db.copy('pa_support_links', new_support_links,
                        ('supported_mk_hash', 'supporting_mk_hash'))
                existing_links |= new_support_links
        except Exception:
            logger.info("Stashing support links found so far.")
            if new_support_links:
                with open(support_link_stash, 'wb') as f:
                    pickle.dump({'existing links': existing_links,
                                 'ids done': npa_done}, f)
            raise

        # Remove all the caches so they can't be picked up accidentally later.
        for cache in pickle_stashes:
            if path.exists(cache):
                remove(cache)

        return True

    def _log(self, msg, level='info'):
        """Applies a task specific tag to the log message."""
        if self.__print_logs:
            print("Preassembly Manager [%s] (%s): %s" %
                  (datetime.now(), self.__tag, msg))
        getattr(logger, level)("(%s) %s" % (self.__tag, msg))

    @clockit
    def _clean_statements(self, stmts):
        """Perform grounding, sequence mapping, and find unique set from stmts.

        This method returns a list of statement objects, as well as a set of
        tuples of the form (uuid, matches_key) which represent the links between
        raw (evidence) statements and their unique/preassembled counterparts.
        """
        self._log("Map grounding...")
        stmts = ac.map_grounding(stmts)
        self._log("Map sequences...")
        stmts = ac.map_sequence(stmts, use_cache=True)
        return stmts

    @clockit
    def _get_support_links(self, unique_stmts, **generate_id_map_kwargs):
        """Find the links of refinement/support between statements."""
        id_maps = self.pa._generate_id_maps(unique_stmts,
                                            **generate_id_map_kwargs)
        ret = set()
        for ix_pair in id_maps:
            assert ix_pair[0] != ix_pair[1], "Self-comparison occurred."
            hash_pair = tuple(shash(unique_stmts[ix]) for ix in ix_pair)
            assert hash_pair[0] != hash_pair[1], \
                "Input list included duplicates."
            ret.add(hash_pair)

        return ret
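
# The classes above also assume a flatten_evidence_dict helper that turns the
# {pa_stmt_mk_hash: {raw_stmt_id, ...}} mapping built during condensation into
# (pa_stmt_mk_hash, raw_stmt_id) tuples suitable for db.copy into the
# raw_unique_links table. A minimal sketch, which may differ from the actual
# INDRA implementation, is:
def flatten_evidence_dict(ev_dict):
    return {(pa_hash, raw_stmt_id)
            for pa_hash, raw_ids in ev_dict.items()
            for raw_stmt_id in raw_ids}
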
Example #5
class PreassemblyManager(object):
    """Class used to manage the preassembly pipeline

    Parameters
    ----------
    n_proc : int
        Select the number of processes that will be used when performing
        preassembly. Default is 1.
    batch_size : int
        Select the maximum number of statements you wish to be handled at a
        time. In general, a larger batch size will be somewhat faster, but
        will require much more memory.
    """
    def __init__(self, n_proc=1, batch_size=10000, print_logs=False):
        self.n_proc = n_proc
        self.batch_size = batch_size
        self.pa = Preassembler(bio_ontology)
        self.__tag = 'Unpurposed'
        self.__print_logs = print_logs
        self.pickle_stashes = None
        return

    def _get_latest_updatetime(self, db):
        """Get the date of the latest update."""
        update_list = db.select_all(db.PreassemblyUpdates)
        if not len(update_list):
            logger.warning("The preassembled corpus has not been initialized, "
                           "or else the updates table has not been populated.")
            return None
        return max([u.run_datetime for u in update_list])

    def _raw_sid_stmt_iter(self, db, id_set, do_enumerate=False):
        """Return a generator over statements with the given database ids."""
        def _fixed_raw_stmt_from_json(s_json, tr):
            stmt = _stmt_from_json(s_json)
            if tr is not None:
                stmt.evidence[0].pmid = tr.pmid
                stmt.evidence[0].text_refs = {
                    k: v
                    for k, v in tr.__dict__.items() if not k.startswith('_')
                }
            return stmt

        i = 0
        for stmt_id_batch in batch_iter(id_set, self.batch_size):
            subres = (db.filter_query(
                [db.RawStatements.id, db.RawStatements.json, db.TextRef],
                db.RawStatements.id.in_(stmt_id_batch)).outerjoin(
                    db.Reading).outerjoin(db.TextContent).outerjoin(
                        db.TextRef).yield_per(self.batch_size // 10))
            data = [(sid, _fixed_raw_stmt_from_json(s_json, tr))
                    for sid, s_json, tr in subres]
            if do_enumerate:
                yield i, data
                i += 1
            else:
                yield data
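
    # Note on the join above: each raw statement is outer-joined through
    # Reading and TextContent to its TextRef, and the pmid and text_refs of
    # the statement's first evidence are patched from that TextRef so the
    # evidence carries the reference information recorded in the database.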

    @clockit
    def _extract_and_push_unique_statements(self,
                                            db,
                                            raw_sids,
                                            num_stmts,
                                            mk_done=None):
        """Get the unique Statements from the raw statements."""
        self._log("There are %d distilled raw statement ids to preassemble." %
                  len(raw_sids))

        if mk_done is None:
            mk_done = set()

        new_mk_set = set()
        num_batches = num_stmts / self.batch_size
        for i, stmt_tpl_batch in self._raw_sid_stmt_iter(db, raw_sids, True):
            self._log("Processing batch %d/%d of %d/%d statements." %
                      (i, num_batches, len(stmt_tpl_batch), num_stmts))

            # Get a list of statements and generate a mapping from uuid to sid.
            stmts = []
            uuid_sid_dict = {}
            for sid, stmt in stmt_tpl_batch:
                uuid_sid_dict[stmt.uuid] = sid
                stmts.append(stmt)

            # Map groundings and sequences.
            cleaned_stmts, eliminated_uuids = self._clean_statements(stmts)
            discarded_stmts = [
                (uuid_sid_dict[uuid], reason)
                for reason, uuid_set in eliminated_uuids.items()
                for uuid in uuid_set
            ]
            db.copy('discarded_statements',
                    discarded_stmts, ('stmt_id', 'reason'),
                    commit=False)

            # Use the shallow hash to condense unique statements.
            new_unique_stmts, evidence_links, agent_tuples = \
                self._condense_statements(cleaned_stmts, mk_done, new_mk_set,
                                          uuid_sid_dict)

            # Insert the statements and their links.
            self._log("Insert new statements into database...")
            insert_pa_stmts(db,
                            new_unique_stmts,
                            ignore_agents=True,
                            commit=False)
            gatherer.add('stmts', len(new_unique_stmts))

            self._log("Insert new raw_unique links into the database...")
            ev_links = flatten_evidence_dict(evidence_links)
            db.copy('raw_unique_links',
                    ev_links, ('pa_stmt_mk_hash', 'raw_stmt_id'),
                    commit=False)
            gatherer.add('evidence', len(ev_links))

            db.copy_lazy('pa_agents',
                         hash_pa_agents(agent_tuples),
                         ('stmt_mk_hash', 'ag_num', 'db_name', 'db_id', 'role',
                          'agent_ref_hash'),
                         commit=False)
            insert_pa_agents(db,
                             new_unique_stmts,
                             verbose=True,
                             skip=['agents'])  # This will commit

        self._log("Added %d new pa statements into the database." %
                  len(new_mk_set))
        return new_mk_set

    @clockit
    def _condense_statements(self, cleaned_stmts, mk_done, new_mk_set,
                             uuid_sid_dict):
        self._log("Condense into unique statements...")
        new_unique_stmts = []
        evidence_links = defaultdict(lambda: set())
        agent_tuples = set()
        for s in cleaned_stmts:
            h = s.get_hash(refresh=True)

            # If this statement is new, make it.
            if h not in mk_done and h not in new_mk_set:
                new_unique_stmts.append(s.make_generic_copy())
                new_mk_set.add(h)

            # Add the evidence to the dict.
            evidence_links[h].add(uuid_sid_dict[s.uuid])

            # Add any db refs to the agents.
            ref_data, _, _ = extract_agent_data(s, h)
            agent_tuples |= set(ref_data)

        return new_unique_stmts, evidence_links, agent_tuples
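
    # In this version the condensation also collects agent reference tuples
    # (via extract_agent_data) keyed by the statement hash; these are later
    # hashed with hash_pa_agents and copied lazily into the pa_agents table
    # alongside the new preassembled statements.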

    @_handle_update_table
    @DGContext.wrap(gatherer)
    def create_corpus(self, db, continuing=False):
        """Initialize the table of preassembled statements.

        This method will find the set of unique knowledge represented in the
        table of raw statements, and it will populate the table of preassembled
        statements (PAStatements/pa_statements), while maintaining links between
        the raw statements and their unique (pa) counterparts. Furthermore, the
        refinement/support relationships between unique statements will be found
        and recorded in the PASupportLinks/pa_support_links table.

        For more detail on preassembly, see indra/preassembler/__init__.py
        """
        self.__tag = 'create'

        if not continuing:
            # Make sure the discarded statements table is cleared.
            db.drop_tables([db.DiscardedStatements])
            db.create_tables([db.DiscardedStatements])
            db.session.close()
            db.grab_session()
        else:
            # Get discarded statements
            skip_ids = {
                i
                for i, in db.select_all(db.DiscardedStatements.stmt_id)
            }
            self._log("Found %d discarded statements from earlier run." %
                      len(skip_ids))

        # Get filtered statement ID's.
        sid_cache_fname = path.join(HERE, 'stmt_id_cache.pkl')
        if continuing and path.exists(sid_cache_fname):
            with open(sid_cache_fname, 'rb') as f:
                stmt_ids = pickle.load(f)
        else:
            # Get the statement ids.
            stmt_ids = distill_stmts(db)
            with open(sid_cache_fname, 'wb') as f:
                pickle.dump(stmt_ids, f)

        # Handle the possibility we're picking up after an earlier job...
        done_pa_ids = set()
        if continuing:
            self._log("Getting set of statements already de-duplicated...")
            link_resp = db.select_all([
                db.RawUniqueLinks.raw_stmt_id,
                db.RawUniqueLinks.pa_stmt_mk_hash
            ])
            if link_resp:
                checked_raw_stmt_ids, pa_stmt_hashes = \
                    zip(*db.select_all([db.RawUniqueLinks.raw_stmt_id,
                                        db.RawUniqueLinks.pa_stmt_mk_hash]))
                stmt_ids -= set(checked_raw_stmt_ids)
                self._log("Found %d raw statements without links to unique." %
                          len(stmt_ids))
                stmt_ids -= skip_ids
                self._log("Found %d raw statements that still need to be "
                          "processed." % len(stmt_ids))
                done_pa_ids = set(pa_stmt_hashes)
                self._log("Found %d preassembled statements already done." %
                          len(done_pa_ids))

        # Get the set of unique statements
        self._extract_and_push_unique_statements(db, stmt_ids, len(stmt_ids),
                                                 done_pa_ids)

        # If we are continuing, check for support links that were already found
        if continuing:
            self._log("Getting pre-existing links...")
            db_existing_links = db.select_all([
                db.PASupportLinks.supported_mk_hash,
                db.PASupportLinks.supporting_mk_hash
            ])
            existing_links = {tuple(res) for res in db_existing_links}
            self._log("Found %d existing links." % len(existing_links))
        else:
            existing_links = set()

        # Now get the support links between all batches.
        support_links = set()
        outer_iter = db.select_all_batched(self.batch_size,
                                           db.PAStatements.json,
                                           order_by=db.PAStatements.mk_hash)
        for outer_idx, outer_batch_jsons in outer_iter:
            outer_batch = [_stmt_from_json(sj) for sj, in outer_batch_jsons]
            # Get internal support links
            self._log('Getting internal support links outer batch %d.' %
                      outer_idx)
            some_support_links = self._get_support_links(outer_batch)

            # Get links with all other batches
            inner_iter = db.select_all_batched(
                self.batch_size,
                db.PAStatements.json,
                order_by=db.PAStatements.mk_hash,
                skip_idx=outer_idx)
            for inner_idx, inner_batch_jsons in inner_iter:
                inner_batch = [
                    _stmt_from_json(sj) for sj, in inner_batch_jsons
                ]
                split_idx = len(inner_batch)
                full_list = inner_batch + outer_batch
                self._log('Getting support between outer batch %d and inner '
                          'batch %d.' % (outer_idx, inner_idx))
                some_support_links |= \
                    self._get_support_links(full_list, split_idx=split_idx)

            # Add all the new support links
            support_links |= (some_support_links - existing_links)

            # There are generally few support links compared to the number of
            # statements, so it doesn't make sense to copy every time, but for
            # long preassembly, this allows for better failure recovery.
            if len(support_links) >= self.batch_size:
                self._log("Copying batch of %d support links into db." %
                          len(support_links))
                db.copy('pa_support_links', support_links,
                        ('supported_mk_hash', 'supporting_mk_hash'))
                gatherer.add('links', len(support_links))
                existing_links |= support_links
                support_links = set()

        # Insert any remaining support links.
        if support_links:
            self._log("Copying final batch of %d support links into db." %
                      len(support_links))
            db.copy('pa_support_links', support_links,
                    ('supported_mk_hash', 'supporting_mk_hash'))
            gatherer.add('links', len(support_links))

        # Delete the pickle cache
        if path.exists(sid_cache_fname):
            remove(sid_cache_fname)

        return True

    def _get_new_stmt_ids(self, db):
        """Get all the uuids of statements not included in evidence."""
        old_id_q = db.filter_query(
            db.RawStatements.id,
            db.RawStatements.id == db.RawUniqueLinks.raw_stmt_id)
        new_sid_q = db.filter_query(db.RawStatements.id).except_(old_id_q)
        all_new_stmt_ids = {sid for sid, in new_sid_q.all()}
        self._log("Found %d new statement ids." % len(all_new_stmt_ids))
        return all_new_stmt_ids

    def _supplement_statements(self, db, continuing=False):
        """Supplement the preassembled statements with the latest content."""
        self.__tag = 'supplement'

        last_update = self._get_latest_updatetime(db)
        assert last_update is not None, \
            "The preassembly tables have not yet been initialized."
        start_date = datetime.utcnow()
        self._log("Latest update was: %s" % last_update)

        # Get the new statements...
        self._log("Loading info about the existing state of preassembly. "
                  "(This may take a little time)")
        new_id_stash = 'new_ids.pkl'
        self.pickle_stashes.append(new_id_stash)
        if continuing and path.exists(new_id_stash):
            self._log("Loading new statement ids from cache...")
            with open(new_id_stash, 'rb') as f:
                new_ids = pickle.load(f)
        else:
            new_ids = self._get_new_stmt_ids(db)

            # Stash the new ids in case we need to pick up where we left off.
            with open(new_id_stash, 'wb') as f:
                pickle.dump(new_ids, f)

        # Weed out exact duplicates.
        dist_stash = 'stmt_ids.pkl'
        self.pickle_stashes.append(dist_stash)
        if continuing and path.exists(dist_stash):
            self._log("Loading distilled statement ids from cache...")
            with open(dist_stash, 'rb') as f:
                stmt_ids = pickle.load(f)
        else:
            stmt_ids = distill_stmts(db, get_full_stmts=False)
            with open(dist_stash, 'wb') as f:
                pickle.dump(stmt_ids, f)

        # Get discarded statements
        skip_ids = {i for i, in db.select_all(db.DiscardedStatements.stmt_id)}

        # Select only the good new statement ids.
        new_stmt_ids = (new_ids & stmt_ids) - skip_ids

        # Get the set of new unique statements and link to any new evidence.
        old_mk_set = {mk for mk, in db.select_all(db.PAStatements.mk_hash)}
        self._log("Found %d old pa statements." % len(old_mk_set))
        new_mk_stash = 'new_mk_set.pkl'
        self.pickle_stashes.append(new_mk_stash)
        if continuing and path.exists(new_mk_stash):
            self._log("Loading hashes for new pa statements from cache...")
            with open(new_mk_stash, 'rb') as f:
                stash_dict = pickle.load(f)
            start_date = stash_dict['start']
            end_date = stash_dict['end']
            new_mk_set = stash_dict['mk_set']
        else:
            new_mk_set = self._extract_and_push_unique_statements(
                db, new_stmt_ids, len(new_stmt_ids), old_mk_set)
            end_date = datetime.utcnow()
            with open(new_mk_stash, 'wb') as f:
                pickle.dump(
                    {
                        'start': start_date,
                        'end': end_date,
                        'mk_set': new_mk_set
                    }, f)
        if continuing:
            self._log("Original old mk set: %d" % len(old_mk_set))
            old_mk_set = old_mk_set - new_mk_set
            self._log("Adjusted old mk set: %d" % len(old_mk_set))

        self._log("Found %d new pa statements." % len(new_mk_set))
        self.__tag = 'Unpurposed'
        return start_date, end_date
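
    # The (start_date, end_date) window returned here is used by
    # _supplement_support below to select only the preassembled statements
    # created during this update (via PAStatements.create_date), so that
    # support links are recomputed just for the new statements.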

    def _supplement_support(self, db, start_date, end_date, continuing=False):
        """Calculate the support for the given date range of pa statements."""
        self.__tag = 'supplement'

        # If we are continuing, check for support links that were already found
        support_link_stash = 'new_support_links.pkl'
        self.pickle_stashes.append(support_link_stash)
        if continuing and path.exists(support_link_stash):
            with open(support_link_stash, 'rb') as f:
                status_dict = pickle.load(f)
                new_support_links = status_dict['existing links']
                npa_done = status_dict['ids done']
            self._log("Found %d previously found new links." %
                      len(new_support_links))
        else:
            new_support_links = set()
            npa_done = set()

        self._log("Downloading all pre-existing support links")
        existing_links = {(a, b)
                          for a, b in db.select_all([
                              db.PASupportLinks.supported_mk_hash,
                              db.PASupportLinks.supporting_mk_hash
                          ])}
        # Just in case...
        new_support_links -= existing_links

        # Now find the new support links that need to be added.
        batching_args = (self.batch_size, db.PAStatements.json,
                         db.PAStatements.create_date >= start_date,
                         db.PAStatements.create_date <= end_date)
        npa_json_iter = db.select_all_batched(*batching_args,
                                              order_by=db.PAStatements.mk_hash)
        for outer_idx, npa_json_batch in npa_json_iter:
            # Create the statements from the jsons.
            npa_batch = []
            for s_json, in npa_json_batch:
                s = _stmt_from_json(s_json)
                if s.get_hash(shallow=True) not in npa_done:
                    npa_batch.append(s)

            # Compare internally
            self._log("Getting support for new pa batch %d." % outer_idx)
            some_support_links = self._get_support_links(npa_batch)

            try:
                # Compare against the other new batch statements.
                other_npa_json_iter = db.select_all_batched(
                    *batching_args,
                    order_by=db.PAStatements.mk_hash,
                    skip_idx=outer_idx)
                for inner_idx, other_npa_json_batch in other_npa_json_iter:
                    other_npa_batch = [
                        _stmt_from_json(s_json)
                        for s_json, in other_npa_json_batch
                    ]
                    split_idx = len(npa_batch)
                    full_list = npa_batch + other_npa_batch
                    self._log("Comparing outer batch %d to inner batch %d of "
                              "other new statements." % (outer_idx, inner_idx))
                    some_support_links |= \
                        self._get_support_links(full_list, split_idx=split_idx)

                # Compare against the existing statements.
                opa_json_iter = db.select_all_batched(
                    self.batch_size, db.PAStatements.json,
                    db.PAStatements.create_date < start_date)
                for opa_idx, opa_json_batch in opa_json_iter:
                    opa_batch = [
                        _stmt_from_json(s_json) for s_json, in opa_json_batch
                    ]
                    split_idx = len(npa_batch)
                    full_list = npa_batch + opa_batch
                    self._log("Comparing new batch %d to batch %d of old "
                              "statements." % (outer_idx, opa_idx))
                    some_support_links |= \
                        self._get_support_links(full_list, split_idx=split_idx)
            finally:
                # Stash the new support links in case we crash.
                new_support_links |= (some_support_links - existing_links)
                with open(support_link_stash, 'wb') as f:
                    pickle.dump(
                        {
                            'existing links': new_support_links,
                            'ids done': npa_done
                        }, f)
            npa_done |= {s.get_hash(shallow=True) for s in npa_batch}

        # Insert any remaining support links.
        if new_support_links:
            self._log("Copying %d support links into db." %
                      len(new_support_links))
            db.copy('pa_support_links', new_support_links,
                    ('supported_mk_hash', 'supporting_mk_hash'))
            gatherer.add('links', len(new_support_links))
        self.__tag = 'Unpurposed'
        return

    @_handle_update_table
    @DGContext.wrap(gatherer)
    def supplement_corpus(self, db, continuing=False):
        """Update the table of preassembled statements.

        This method will take any new raw statements that have not yet been
        incorporated into the preassembled table, and use them to augment the
        preassembled table.

        The resulting updated table is indistinguishable from the result you
        would achieve if you had simply re-run preassembly on _all_ the
        raw statements.
        """

        self.pickle_stashes = []

        start_date, end_date = self._supplement_statements(db, continuing)
        self._supplement_support(db, start_date, end_date, continuing)

        # Remove all the caches so they can't be picked up accidentally later.
        for cache in self.pickle_stashes:
            if path.exists(cache):
                remove(cache)
        self.pickle_stashes = None

        return True

    def _log(self, msg, level='info'):
        """Applies a task specific tag to the log message."""
        if self.__print_logs:
            print("Preassembly Manager [%s] (%s): %s" %
                  (datetime.now(), self.__tag, msg))
        getattr(logger, level)("(%s) %s" % (self.__tag, msg))

    @clockit
    def _clean_statements(self, stmts):
        """Perform grounding, sequence mapping, and find unique set from stmts.

        This method returns a list of statement objects, as well as a set of
        tuples of the form (uuid, matches_key) which represent the links between
        raw (evidence) statements and their unique/preassembled counterparts.
        """
        eliminated_uuids = {}
        all_uuids = {s.uuid for s in stmts}
        self._log("Map grounding...")
        stmts = ac.map_grounding(stmts, use_adeft=True, gilda_mode='local')
        grounded_uuids = {s.uuid for s in stmts}
        eliminated_uuids['grounding'] = all_uuids - grounded_uuids
        self._log("Map sequences...")
        stmts = ac.map_sequence(stmts, use_cache=True)
        seqmapped_and_grounded_uuids = {s.uuid for s in stmts}
        eliminated_uuids['sequence mapping'] = \
            grounded_uuids - seqmapped_and_grounded_uuids
        return stmts, eliminated_uuids

    @clockit
    def _get_support_links(self, unique_stmts, split_idx=None):
        """Find the links of refinement/support between statements."""
        id_maps = self.pa._generate_id_maps(unique_stmts,
                                            poolsize=self.n_proc,
                                            split_idx=split_idx)
        ret = set()
        for ix_pair in id_maps:
            assert ix_pair[0] != ix_pair[1], "Self-comparison occurred."
            hash_pair = tuple(shash(unique_stmts[ix]) for ix in ix_pair)
            assert hash_pair[0] != hash_pair[1], \
                "Input list included duplicates."
            ret.add(hash_pair)

        return ret
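

# A minimal, hedged driver sketch for the manager class above. It assumes the
# class above is named PreassemblyManager, that its constructor accepts
# n_proc and batch_size, and that a database handle can be obtained via
# indra_db.util.get_db; all of these are illustrative assumptions, not part
# of the original module.
if __name__ == '__main__':
    from indra_db.util import get_db
    db = get_db('primary')  # assumed helper returning a DatabaseManager
    manager = PreassemblyManager(n_proc=4, batch_size=10000)
    # Augment the existing preassembled corpus with any new raw statements.
    manager.supplement_corpus(db, continuing=False)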
Example #6
0
class DbPreassembler:
    """Class used to manage the preassembly pipeline

    Parameters
    ----------
    batch_size : int
        The maximum number of statements to handle at a time. In general, a
        larger batch size is somewhat faster but requires much more memory.
    s3_cache : Optional[S3Path]
        If given, an S3 location under which intermediate results are cached,
        allowing an interrupted job to be resumed. The cache is made specific
        to ``stmt_type``.
    print_logs : bool
        If True, also print log messages to stdout in addition to the logger.
    stmt_type : Optional[str]
        If given, restrict preassembly to raw statements of this type.
    yes_all : bool
        If True, answer "yes" to all interactive prompts.
    ontology : Optional[IndraOntology]
        The ontology used by the internal Preassembler. If not given, the
        INDRA bio_ontology is used and its transitive closure is built.
    """
    def __init__(self,
                 batch_size=10000,
                 s3_cache=None,
                 print_logs=False,
                 stmt_type=None,
                 yes_all=False,
                 ontology=None):
        self.batch_size = batch_size
        if s3_cache is not None:
            # Make the cache specific to the statement type. This guards
            # against accidentally mixing cached results generated for
            # different statement types.
            if not isinstance(s3_cache, S3Path):
                raise TypeError(
                    f"Expected s3_cache to be type S3Path, but got "
                    f"type {type(s3_cache)}.")
            specifications = f'st_{stmt_type}/'
            self.s3_cache = s3_cache.get_element_path(specifications)

            # Report on any caches that already exist. This should help avoid
            # re-doing work simply because a different batch size was used.
            import boto3
            s3 = boto3.client('s3')
            if s3_cache.exists(s3):
                if self.s3_cache.exists(s3):
                    logger.info(f"A prior run with these parameters exists in "
                                f"the cache: {s3_cache}.")
                else:
                    logger.info(f"Prior job or jobs with different Statement "
                                f"type exist for the cache: {s3_cache}.")
            else:
                logger.info(f"No prior jobs appear in the cache: {s3_cache}.")
        else:
            self.s3_cache = None
        if ontology is None:
            ontology = bio_ontology
            ontology.initialize()
            ontology._build_transitive_closure()
        self.pa = Preassembler(ontology)
        self.__tag = 'Unpurposed'
        self.__print_logs = print_logs
        self.pickle_stashes = None
        self.stmt_type = stmt_type
        self.yes_all = yes_all
        return

    def _yes_input(self, message, default='yes'):
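        """Prompt the user with a yes/no question and return the answer.

        If `yes_all` was set on this instance, return True without prompting.
        """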
        if self.yes_all:
            return True

        valid = {'yes': True, 'ye': True, 'y': True, 'no': False, 'n': False}

        if default is None:
            prompt = '[y/n]'
        elif default == 'yes':
            prompt = '[Y/n]'
        elif default == 'no':
            prompt = '[y/N]'
        else:
            raise ValueError(f"Argument 'default' must be 'yes' or 'no', got "
                             f"'{default}'.")

        resp = input(f'{message} {prompt}: ')
        while True:
            if resp == '' and default is not None:
                return valid[default]
            elif resp.lower() in valid:
                return valid[resp.lower()]
            resp = input(f'Please answer "yes" (or "y") or "no" (or "n"). '
                         f'{prompt}: ')

    def _get_latest_updatetime(self, db):
        """Get the date of the latest update."""
        if self.stmt_type is not None:
            st_const = or_(db.PreassemblyUpdates.stmt_type == self.stmt_type,
                           db.PreassemblyUpdates.stmt_type.is_(None))
        else:
            st_const = db.PreassemblyUpdates.stmt_type.is_(None)
        update_list = db.select_all(db.PreassemblyUpdates, st_const)
        if not len(update_list):
            logger.warning("The preassembled corpus has not been initialized, "
                           "or else the updates table has not been populated.")
            return None
        return max([u.run_datetime for u in update_list])

    def _get_cache_path(self, file_name):
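        """Get the S3 path for a cache file under the current task tag."""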
        return (self.s3_cache.get_element_path(
            self.__tag).get_element_path(file_name))

    def _init_cache(self, continuing):
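        """Set up the S3 cache for this run and return its start time.

        If a cache already exists for these parameters, prompt whether to
        continue with it or overwrite it, aborting on a negative answer.
        Returns None if no S3 cache is configured.
        """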
        if self.s3_cache is None:
            return

        import boto3
        s3 = boto3.client('s3')
        start_file = self._get_cache_path('start.pkl')
        if start_file.exists(s3):
            s3_resp = start_file.get(s3)
            start_data = pickle.loads(s3_resp['Body'].read())
            start_time = start_data['start_time']
            cache_desc = f"Do you want to %s {self._get_cache_path('')} " \
                         f"started {start_time}?"
            if continuing:
                if self._yes_input(cache_desc % 'continue with'):
                    return start_time
                else:
                    raise UserQuit("Aborting job.")
            elif not self._yes_input(cache_desc % 'overwrite existing',
                                     default='no'):
                raise UserQuit("Aborting job.")

            self._clear_cache()

        start_time = datetime.utcnow()
        start_data = {'start_time': start_time}
        start_file.put(s3, pickle.dumps(start_data))
        return start_time

    def _clear_cache(self):
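        """Delete all objects under this run's S3 cache path."""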
        if self.s3_cache is None:
            return
        import boto3
        s3 = boto3.client('s3')
        objects = self._get_cache_path('').list_objects(s3)
        for s3_path in objects:
            s3_path.delete(s3)
        return

    def _run_cached(self, continuing, func, *args, **kwargs):
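        """Run `func`, caching its pickled return value in the S3 cache.

        If continuing and a cached result for `func` already exists, return
        the cached value instead of re-running the function.
        """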
        if self.s3_cache is None:
            return func(*args, **kwargs)

        # Define the location of this cache.
        import boto3
        s3 = boto3.client('s3')
        result_cache = self._get_cache_path(f'{func.__name__}.pkl')

        # If continuing, try to retrieve the file.
        if continuing and result_cache.exists(s3):
            s3_result = result_cache.get(s3)
            return pickle.loads(s3_result['Body'].read())

        # If not continuing or the file doesn't exist, run the function.
        results = func(*args, **kwargs)
        pickle_data = pickle.dumps(results)
        result_cache.put(s3, pickle_data)
        return results

    def _put_support_mark(self, outer_idx):
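        """Record in the S3 cache how far the support calculation has gotten."""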
        if self.s3_cache is None:
            return

        import boto3
        s3 = boto3.client('s3')

        supp_file = self._get_cache_path('support_idx.pkl')
        supp_file.put(s3, pickle.dumps(outer_idx * self.batch_size))
        return

    def _get_support_mark(self, continuing):
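        """Get the index of the last support batch dumped, or -1 if none.

        The mark is only consulted when continuing a previous run with an S3
        cache configured.
        """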
        if self.s3_cache is None:
            return -1

        if not continuing:
            return -1

        import boto3
        s3 = boto3.client('s3')

        supp_file = self._get_cache_path('support_idx.pkl')
        if not supp_file.exists(s3):
            return -1
        s3_resp = supp_file.get(s3)
        return pickle.loads(s3_resp['Body'].read()) // self.batch_size

    def _raw_sid_stmt_iter(self, db, id_set, do_enumerate=False):
        """Return a generator over statements with the given database ids."""
        def _fixed_raw_stmt_from_json(s_json, tr):
            stmt = _stmt_from_json(s_json)
            if tr is not None:
                stmt.evidence[0].pmid = tr.pmid
                stmt.evidence[0].text_refs = {
                    k: v
                    for k, v in tr.__dict__.items() if not k.startswith('_')
                }
            return stmt

        i = 0
        for stmt_id_batch in batch_iter(id_set, self.batch_size):
            subres = (db.filter_query(
                [db.RawStatements.id, db.RawStatements.json, db.TextRef],
                db.RawStatements.id.in_(stmt_id_batch)).outerjoin(
                    db.Reading).outerjoin(db.TextContent).outerjoin(
                        db.TextRef).yield_per(self.batch_size // 10))
            data = [(sid, _fixed_raw_stmt_from_json(s_json, tr))
                    for sid, s_json, tr in subres]
            if do_enumerate:
                yield i, data
                i += 1
            else:
                yield data

    def _make_idx_batches(self, hash_list, continuing):
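        """Split the hash list into (start, end) index batches.

        Also return the index of the first batch still to be processed, based
        on any support mark left in the cache by a previous run.
        """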
        N = len(hash_list)
        B = self.batch_size
        idx_batch_list = [(n * B, min((n + 1) * B, N))
                          for n in range(0, N // B + 1)]
        start_idx = self._get_support_mark(continuing) + 1
        return idx_batch_list, start_idx

    @clockit
    def _extract_and_push_unique_statements(self,
                                            db,
                                            raw_sids,
                                            num_stmts,
                                            mk_done=None):
        """Get the unique Statements from the raw statements."""
        self._log("There are %d distilled raw statement ids to preassemble." %
                  len(raw_sids))

        if mk_done is None:
            mk_done = set()

        new_mk_set = set()
        num_batches = num_stmts / self.batch_size
        for i, stmt_tpl_batch in self._raw_sid_stmt_iter(db, raw_sids, True):
            self._log("Processing batch %d/%d of %d/%d statements." %
                      (i, num_batches, len(stmt_tpl_batch), num_stmts))

            # Get a list of statements and generate a mapping from uuid to sid.
            stmts = []
            uuid_sid_dict = {}
            for sid, stmt in stmt_tpl_batch:
                uuid_sid_dict[stmt.uuid] = sid
                stmts.append(stmt)

            # Map groundings and sequences.
            cleaned_stmts, eliminated_uuids = self._clean_statements(stmts)
            discarded_stmts = [
                (uuid_sid_dict[uuid], reason)
                for reason, uuid_set in eliminated_uuids.items()
                for uuid in uuid_set
            ]
            db.copy('discarded_statements',
                    discarded_stmts, ('stmt_id', 'reason'),
                    commit=False)

            # Use the shallow hash to condense unique statements.
            new_unique_stmts, evidence_links, agent_tuples = \
                self._condense_statements(cleaned_stmts, mk_done, new_mk_set,
                                          uuid_sid_dict)

            # Insert the statements and their links.
            self._log("Insert new statements into database...")
            insert_pa_stmts(db,
                            new_unique_stmts,
                            ignore_agents=True,
                            commit=False)
            gatherer.add('stmts', len(new_unique_stmts))

            self._log("Insert new raw_unique links into the database...")
            ev_links = flatten_evidence_dict(evidence_links)
            db.copy('raw_unique_links',
                    ev_links, ('pa_stmt_mk_hash', 'raw_stmt_id'),
                    commit=False)
            gatherer.add('evidence', len(ev_links))

            db.copy_lazy('pa_agents',
                         hash_pa_agents(agent_tuples),
                         ('stmt_mk_hash', 'ag_num', 'db_name', 'db_id', 'role',
                          'agent_ref_hash'),
                         commit=False)
            insert_pa_agents(db,
                             new_unique_stmts,
                             verbose=True,
                             skip=['agents'])  # This will commit

        self._log("Added %d new pa statements into the database." %
                  len(new_mk_set))
        return new_mk_set

    @clockit
    def _condense_statements(self, cleaned_stmts, mk_done, new_mk_set,
                             uuid_sid_dict):
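        """Condense cleaned statements into new unique (pa) statements.

        Returns the list of new unique statements, a dict mapping each
        statement hash to the raw statement ids providing its evidence, and
        the set of agent reference tuples extracted from the statements.
        """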
        self._log("Condense into unique statements...")
        new_unique_stmts = []
        evidence_links = defaultdict(lambda: set())
        agent_tuples = set()
        for s in cleaned_stmts:
            h = s.get_hash(refresh=True)

            # If this statement is new, make it.
            if h not in mk_done and h not in new_mk_set:
                new_unique_stmts.append(s.make_generic_copy())
                new_mk_set.add(h)

            # Add the evidence to the dict.
            evidence_links[h].add(uuid_sid_dict[s.uuid])

            # Add any db refs to the agents.
            ref_data, _, _ = extract_agent_data(s, h)
            agent_tuples |= set(ref_data)

        return new_unique_stmts, evidence_links, agent_tuples

    def _dump_links(self, db, supp_links):
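        """Copy a batch of support links into the pa_support_links table."""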
        self._log(f"Copying batch of {len(supp_links)} support links into db.")
        skipped = db.copy_report_lazy(
            'pa_support_links', supp_links,
            ('supported_mk_hash', 'supporting_mk_hash'))
        gatherer.add('links', len(supp_links - set(skipped)))
        return

    @_handle_update_table
    @DGContext.wrap(gatherer)
    def create_corpus(self, db, continuing=False):
        """Initialize the table of preassembled statements.

        This method will find the set of unique knowledge represented in the
        table of raw statements, and it will populate the table of preassembled
        statements (PAStatements/pa_statements), while maintaining links between
        the raw statements and their unique (pa) counterparts. Furthermore, the
        refinement/support relationships between unique statements will be found
        and recorded in the PASupportLinks/pa_support_links table.

        For more detail on preassembly, see indra/preassembler/__init__.py
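
        Examples
        --------
        A minimal usage sketch. The database handle obtained via ``get_db``
        and the parameter values below are illustrative assumptions, not
        requirements of this method::

            >>> from indra_db.util import get_db
            >>> db = get_db('primary')
            >>> preassembler = DbPreassembler(stmt_type='Phosphorylation',
            ...                               yes_all=True)
            >>> _ = preassembler.create_corpus(db)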
        """
        self.__tag = 'create'
        self._init_cache(continuing)

        if continuing:
            # Get discarded statements
            skip_ids = {
                i
                for i, in db.select_all(db.DiscardedStatements.stmt_id)
            }
            self._log("Found %d discarded statements from earlier run." %
                      len(skip_ids))

        # Get filtered statement ID's.
        if self.stmt_type is not None:
            clauses = [db.RawStatements.type == self.stmt_type]
        else:
            clauses = []
        stmt_ids = self._run_cached(continuing,
                                    distill_stmts,
                                    db,
                                    clauses=clauses)

        # Handle the possibility we're picking up after an earlier job...
        mk_done = set()
        if continuing:
            self._log("Getting set of statements already de-duplicated...")
            link_q = db.filter_query([
                db.RawUniqueLinks.raw_stmt_id,
                db.RawUniqueLinks.pa_stmt_mk_hash
            ])
            if self.stmt_type is not None:
                link_q = (link_q.join(db.RawStatements).filter(
                    db.RawStatements.type == self.stmt_type))
            link_resp = link_q.all()
            if link_resp:
                checked_raw_stmt_ids, pa_stmt_hashes = zip(*link_resp)
                stmt_ids -= set(checked_raw_stmt_ids)
                self._log("Found %d raw statements without links to unique." %
                          len(stmt_ids))
                stmt_ids -= skip_ids
                self._log("Found %d raw statements that still need to be "
                          "processed." % len(stmt_ids))
                mk_done = set(pa_stmt_hashes)
                self._log("Found %d preassembled statements already done." %
                          len(mk_done))

        # Get the set of unique statements
        new_mk_set = self._run_cached(continuing,
                                      self._extract_and_push_unique_statements,
                                      db, stmt_ids, len(stmt_ids), mk_done)

        # Now get the support links between all batches.
        support_links = set()
        hash_list = list(new_mk_set | mk_done)
        self._log(f"Beginning to find support relations for {len(hash_list)} "
                  f"new statements.")
        hash_list.sort()
        idx_batches, start_idx = self._make_idx_batches(hash_list, continuing)
        for outer_idx, (out_si, out_ei) in enumerate(idx_batches[start_idx:]):
            outer_idx += start_idx
            sj_query = db.filter_query(
                db.PAStatements.json,
                db.PAStatements.mk_hash.in_(hash_list[out_si:out_ei]))
            outer_batch = [_stmt_from_json(sj) for sj, in sj_query.all()]

            # Get internal support links
            self._log(f'Getting internal support links outer batch '
                      f'{outer_idx}/{len(idx_batches)-1}.')
            some_support_links = self._get_support_links(outer_batch)

            # Get links with all other batches
            in_start = outer_idx + 1
            for inner_idx, (in_si, in_ei) in enumerate(idx_batches[in_start:]):
                inner_sj_q = db.filter_query(
                    db.PAStatements.json,
                    db.PAStatements.mk_hash.in_(hash_list[in_si:in_ei]))
                inner_batch = [_stmt_from_json(sj) for sj, in inner_sj_q.all()]
                split_idx = len(inner_batch)
                full_list = inner_batch + outer_batch
                self._log(f'Getting support between outer batch {outer_idx}/'
                          f'{len(idx_batches)-1} and inner batch {inner_idx}/'
                          f'{len(idx_batches)-in_start-1}.')
                some_support_links |= \
                    self._get_support_links(full_list, split_idx=split_idx)

            # Add all the new support links
            support_links |= some_support_links

            # There are generally few support links relative to the number of
            # statements, so copying to the database on every batch isn't
            # worthwhile; for long preassembly runs, however, periodic dumps
            # allow better recovery after a failure.
            if len(support_links) >= self.batch_size:
                self._dump_links(db, support_links)
                self._put_support_mark(outer_idx)
                support_links = set()

        # Insert any remaining support links.
        if support_links:
            self._log('Final (overflow) batch of links.')
            self._dump_links(db, support_links)

        self._clear_cache()
        return True

    def _get_new_stmt_ids(self, db):
        """Get all the uuids of statements not included in evidence."""
        olds_q = db.filter_query(
            db.RawStatements.id,
            db.RawStatements.id == db.RawUniqueLinks.raw_stmt_id)
        if self.stmt_type is not None:
            olds_q = olds_q.filter(db.RawStatements.type == self.stmt_type)
        alls_q = db.filter_query(db.RawStatements.id)
        if self.stmt_type is not None:
            alls_q = alls_q.filter(db.RawStatements.type == self.stmt_type)
        new_id_q = alls_q.except_(olds_q)
        all_new_stmt_ids = {sid for sid, in new_id_q.all()}
        self._log("Found %d new statement ids." % len(all_new_stmt_ids))
        return all_new_stmt_ids

    def _supplement_statements(self, db, continuing=False):
        """Supplement the preassembled statements with the latest content."""

        last_update = self._get_latest_updatetime(db)
        assert last_update is not None, \
            "The preassembly tables have not yet been initialized."
        self._log("Latest update was: %s" % last_update)

        # Get the new statements...
        self._log("Loading info about the existing state of preassembly. "
                  "(This may take a little time)")
        new_ids = self._run_cached(continuing, self._get_new_stmt_ids, db)

        # Weed out exact duplicates.
        if self.stmt_type is not None:
            clauses = [db.RawStatements.type == self.stmt_type]
        else:
            clauses = []
        stmt_ids = self._run_cached(continuing,
                                    distill_stmts,
                                    db,
                                    get_full_stmts=False,
                                    clauses=clauses)

        # Get discarded statements
        skip_ids = {i for i, in db.select_all(db.DiscardedStatements.stmt_id)}

        # Select only the good new statement ids.
        new_stmt_ids = new_ids & stmt_ids - skip_ids

        # Get the set of new unique statements and link to any new evidence.
        old_mk_set = {mk for mk, in db.select_all(db.PAStatements.mk_hash)}
        self._log("Found %d old pa statements." % len(old_mk_set))

        new_mk_set = self._run_cached(continuing,
                                      self._extract_and_push_unique_statements,
                                      db, new_stmt_ids, len(new_stmt_ids),
                                      old_mk_set)

        if continuing:
            self._log("Original old mk set: %d" % len(old_mk_set))
            old_mk_set = old_mk_set - new_mk_set
            self._log("Adjusted old mk set: %d" % len(old_mk_set))

        self._log("Found %d new pa statements." % len(new_mk_set))
        return new_mk_set

    def _supplement_support(self,
                            db,
                            new_hashes,
                            start_time,
                            continuing=False):
        """Calculate the support for the given date range of pa statements."""
        if not isinstance(new_hashes, list):
            new_hashes = list(new_hashes)
        new_hashes.sort()

        # If we are continuing, check for support links that were already found
        support_links = set()
        idx_batches, start_idx = self._make_idx_batches(new_hashes, continuing)
        for outer_idx, (out_s, out_e) in enumerate(idx_batches[start_idx:]):
            outer_idx += start_idx
            # Create the statements from the jsons.
            npa_json_q = db.filter_query(
                db.PAStatements.json,
                db.PAStatements.mk_hash.in_(new_hashes[out_s:out_e]))
            npa_batch = [
                _stmt_from_json(s_json) for s_json, in npa_json_q.all()
            ]

            # Compare internally
            self._log(f"Getting support for new pa batch {outer_idx}/"
                      f"{len(idx_batches)}.")
            some_support_links = self._get_support_links(npa_batch)

            # Compare against the other new batch statements.
            in_start = outer_idx + 1
            for in_idx, (in_s, in_e) in enumerate(idx_batches[in_start:]):
                other_npa_q = db.filter_query(
                    db.PAStatements.json,
                    db.PAStatements.mk_hash.in_(new_hashes[in_s:in_e]))
                other_npa_batch = [
                    _stmt_from_json(sj) for sj, in other_npa_q.all()
                ]
                split_idx = len(npa_batch)
                full_list = npa_batch + other_npa_batch
                self._log(f"Comparing outer batch {outer_idx}/"
                          f"{len(idx_batches)-1} to inner batch {in_idx}/"
                          f"{len(idx_batches)-in_start-1} of other new "
                          f"statements.")
                some_support_links |= \
                    self._get_support_links(full_list, split_idx=split_idx)

            # Compare against the existing statements.
            opa_args = (db.PAStatements.create_date < start_time, )
            if self.stmt_type is not None:
                opa_args += (db.PAStatements.type == self.stmt_type, )

            opa_json_iter = db.select_all_batched(self.batch_size,
                                                  db.PAStatements.json,
                                                  *opa_args)
            for opa_idx, opa_json_batch in opa_json_iter:
                opa_batch = [
                    _stmt_from_json(s_json) for s_json, in opa_json_batch
                ]
                split_idx = len(npa_batch)
                full_list = npa_batch + opa_batch
                self._log(f"Comparing new batch {outer_idx}/"
                          f"{len(idx_batches)-1} to batch {opa_idx} of old pa "
                          f"statements.")
                some_support_links |= \
                    self._get_support_links(full_list, split_idx=split_idx)

            support_links |= some_support_links

            # There are generally few support links relative to the number of
            # statements, so copying to the database on every batch isn't
            # worthwhile; for long preassembly runs, however, periodic dumps
            # allow better recovery after a failure.
            if len(support_links) >= self.batch_size:
                self._dump_links(db, support_links)
                self._put_support_mark(outer_idx)
                support_links = set()

        # Insert any remaining support links.
        if support_links:
            self._log("Final (overflow) batch of new support links.")
            self._dump_links(db, support_links)
        return

    @_handle_update_table
    @DGContext.wrap(gatherer)
    def supplement_corpus(self, db, continuing=False):
        """Update the table of preassembled statements.

        This method will take any new raw statements that have not yet been
        incorporated into the preassembled table, and use them to augment the
        preassembled table.

        The resulting updated table is indistinguishable from the result you
        would achieve if you had simply re-run preassembly on _all_ the
        raw statements.
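
        Examples
        --------
        A minimal usage sketch, assuming a preassembled corpus already exists
        and a database handle ``db`` (e.g. from ``indra_db.util.get_db``); the
        parameter values are illustrative only::

            >>> preassembler = DbPreassembler(batch_size=10000, yes_all=True)
            >>> _ = preassembler.supplement_corpus(db)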
        """
        self.__tag = 'supplement'
        start_time = self._init_cache(continuing)

        self.pickle_stashes = []

        new_hashes = self._supplement_statements(db, continuing)
        self._supplement_support(db, new_hashes, start_time, continuing)

        self._clear_cache()
        self.__tag = 'Unpurposed'
        return True

    def _log(self, msg, level='info'):
        """Applies a task specific tag to the log message."""
        if self.__print_logs:
            print("Preassembly Manager [%s] (%s): %s" %
                  (datetime.now(), self.__tag, msg))
        getattr(logger, level)("(%s) %s" % (self.__tag, msg))

    @clockit
    def _clean_statements(self, stmts):
        """Perform grounding, sequence mapping, and find unique set from stmts.

        This method returns a list of statement objects, as well as a set of
        tuples of the form (uuid, matches_key) which represent the links between
        raw (evidence) statements and their unique/preassembled counterparts.
        """
        eliminated_uuids = {}
        all_uuids = {s.uuid for s in stmts}
        self._log("Map grounding...")
        stmts = ac.map_grounding(stmts, use_adeft=True, gilda_mode='local')
        grounded_uuids = {s.uuid for s in stmts}
        eliminated_uuids['grounding'] = all_uuids - grounded_uuids
        self._log("Map sequences...")
        stmts = ac.map_sequence(stmts, use_cache=True)
        seqmapped_and_grounded_uuids = {s.uuid for s in stmts}
        eliminated_uuids['sequence mapping'] = \
            grounded_uuids - seqmapped_and_grounded_uuids
        return stmts, eliminated_uuids

    @clockit
    def _get_support_links(self, unique_stmts, split_idx=None):
        """Find the links of refinement/support between statements."""
        id_maps = self.pa._generate_id_maps(unique_stmts, split_idx=split_idx)
        ret = set()
        for ix_pair in id_maps:
            assert ix_pair[0] != ix_pair[1], "Self-comparison occurred."
            hash_pair = tuple(shash(unique_stmts[ix]) for ix in ix_pair)
            assert hash_pair[0] != hash_pair[1], \
                "Input list included duplicates."
            ret.add(hash_pair)

        return ret
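

# A standalone sketch of the index-batching arithmetic used by
# _make_idx_batches above; the function name and the values below are
# illustrative only and not part of the class.
def _sketch_idx_batches(n_hashes, batch_size):
    """Return (start, end) slices covering range(n_hashes) in batches."""
    return [(n * batch_size, min((n + 1) * batch_size, n_hashes))
            for n in range(0, n_hashes // batch_size + 1)]


# For 25 hashes and a batch size of 10 this yields [(0, 10), (10, 20), (20, 25)].
# When n_hashes is an exact multiple of batch_size, the final batch is empty,
# mirroring the behavior of _make_idx_batches.
print(_sketch_idx_batches(25, 10))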