def _check_preassembly_with_database(num_stmts, batch_size):
    """Create a preassembled corpus in a test database and sanity-check it.

    A database is obtained from ``get_pa_loaded_db(num_stmts)``, a
    ``PreassemblyManager`` with the given ``batch_size`` builds the corpus,
    and the resulting tables are then checked with a series of assertions:

    * there are raw statements, and fewer (but more than zero) preassembled
      (pa) statements than raw statements;
    * every raw-unique link refers to an existing raw statement and every
      pa statement appears in at least one link;
    * support links exist and none links a hash to itself;
    * ``db_client.get_statements`` returns as many pa statements as the
      table holds, and none of them lists itself in its own support;
    * the pa statements agree with old-fashioned preassembly (opa), as
      judged by ``_check_against_opa_stmts``.

    Raises ``AssertionError`` when any of these checks fails.
    """
    db = get_pa_loaded_db(num_stmts)

    # Gather the input for old-fashioned preassembly (opa); it is compared
    # with the database's preassembled (pa) statements at the very end.
    opa_inp_stmts = _get_opa_input_stmts(db)

    # Get the set of raw statements.
    raw_stmt_list = db.select_all(db.RawStatements)
    all_raw_ids = {raw_stmt.id for raw_stmt in raw_stmt_list}
    assert len(raw_stmt_list)

    # Run the preassembly initialization.
    start = datetime.now()
    pa_manager = pm.PreassemblyManager(batch_size=batch_size, print_logs=True)
    pa_manager.create_corpus(db)
    end = datetime.now()
    print("Duration:", end - start)

    # Make sure the number of pa statements is within reasonable bounds.
    pa_stmt_list = db.select_all(db.PAStatements)
    assert 0 < len(pa_stmt_list) < len(raw_stmt_list)

    # Check the evidence links.
    raw_unique_link_list = db.select_all(db.RawUniqueLinks)
    assert len(raw_unique_link_list)
    all_link_ids = {ru.raw_stmt_id for ru in raw_unique_link_list}
    all_link_mk_hashes = {ru.pa_stmt_mk_hash for ru in raw_unique_link_list}
    # Every linked raw id must exist among the raw statements. This is a
    # value comparison: the previous `len(...) is 0` tested object identity
    # against an int literal, which only works by interpreter accident.
    assert all_link_ids <= all_raw_ids
    assert all(pa_stmt.mk_hash in all_link_mk_hashes
               for pa_stmt in pa_stmt_list)

    # Check the support links.
    sup_links = db.select_all([
        db.PASupportLinks.supporting_mk_hash,
        db.PASupportLinks.supported_mk_hash
    ])
    assert sup_links
    assert not any(link[0] == link[1] for link in sup_links),\
        "Found self-support in the database."

    # Try to get all the preassembled statements from the table.
    pa_stmts = db_client.get_statements([],
                                        preassembled=True,
                                        db=db,
                                        with_support=True)
    assert len(pa_stmts) == len(pa_stmt_list), (len(pa_stmts),
                                                len(pa_stmt_list))

    # No statement may appear among its own supports/supported_by.
    for s in pa_stmts:
        linked_hashes = {shash(s_) for s_ in s.supported_by + s.supports}
        assert shash(s) not in linked_hashes, \
            "Found self-support in constructed pa statement objects."

    _check_against_opa_stmts(db, opa_inp_stmts, pa_stmts)
 def _compare_list_elements(label, list_func, comp_func, **stmts):
     (stmt_1_name, stmt_1), (stmt_2_name, stmt_2) = list(stmts.items())
     vals_1 = [comp_func(elem) for elem in list_func(stmt_1)]
     vals_2 = []
     for element in list_func(stmt_2):
         val = comp_func(element)
         if val in vals_1:
             vals_1.remove(val)
         else:
             vals_2.append(val)
     if len(vals_1) or len(vals_2):
         print("Found mismatched %s for hash %s:\n\t%s=%s\n\t%s=%s" %
               (label, shash(stmt_1), stmt_1_name, vals_1, stmt_2_name,
                vals_2))
         return {
             'diffs': {
                 stmt_1_name: vals_1,
                 stmt_2_name: vals_2
             },
             'stmts': {
                 stmt_1_name: stmt_1,
                 stmt_2_name: stmt_2
             }
         }
     return None
def elaborate_on_hash_diffs(db, lbl, stmt_list, other_stmt_keys):
    """Print what the database holds for each statement in ``stmt_list``.

    For every statement this prints its uuid and hashes, the PA statement
    row whose ``mk_hash`` equals the statement's shallow hash, and the raw
    statement row with the same uuid. When the raw statement came from a
    reading, the reading, its text content and its text ref are printed,
    followed by every text content of that text ref with its readings and
    their raw statements.

    ``lbl`` only appears in the header line. ``uuid`` and
    ``other_stmt_keys`` are forwarded to ``str_imp`` when formatting raw
    statements.

    This is purely diagnostic output; nothing is returned. A row missing
    anywhere along the chain is printed as '~' and ends the elaboration for
    that statement instead of raising.
    """
    print("#" * 100)
    print("Elaboration on extra %s statements:" % lbl)
    print("#" * 100)
    for stmt in stmt_list:
        print(stmt)
        uuid = stmt.uuid
        print('-' * 100)
        print('uuid: %s\nhash: %s\nshallow hash: %s' %
              (uuid, stmt.get_hash(shallow=False), shash(stmt)))
        print('-' * 100)
        db_pas = db.select_one(db.PAStatements,
                               db.PAStatements.mk_hash == shash(stmt))
        print('\tPA statement:', db_pas.__dict__ if db_pas else '~')
        print('-' * 100)
        db_s = db.select_one(db.RawStatements, db.RawStatements.uuid == uuid)
        print('\tRaw statement:', str_imp(db_s, uuid, other_stmt_keys))
        if db_s is None:
            continue
        print('-' * 100)
        if db_s.reading_id is None:
            print("Statement was from a database: %s" % db_s.db_info_id)
            continue

        # Walk up the provenance chain. Each lookup may come back empty, in
        # which case '~' has been printed and there is nothing to follow.
        db_r = db.select_one(db.Reading, db.Reading.id == db_s.reading_id)
        print('\tReading:', str_imp(db_r))
        if db_r is None:
            continue
        db_tc = db.select_one(db.TextContent,
                              db.TextContent.id == db_r.text_content_id)
        print('\tText Content:', str_imp(db_tc))
        if db_tc is None:
            continue
        tr = db.select_one(db.TextRef, db.TextRef.id == db_tc.text_ref_id)
        print('\tText ref:', str_imp(tr))
        if tr is None:
            continue
        print('-' * 100)

        # Walk back down: everything derived from the same text ref. The
        # loop variables are kept distinct from `stmt` and `db_tc` above.
        for tc in db.select_all(db.TextContent,
                                db.TextContent.text_ref_id == tr.id):
            print('\t', str_imp(tc))
            for r in db.select_all(db.Reading,
                                   db.Reading.text_content_id == tc.id):
                print('\t\t', str_imp(r))
                for raw_s in db.select_all(
                        db.RawStatements,
                        db.RawStatements.reading_id == r.id):
                    print('\t\t\t', str_imp(raw_s, uuid, other_stmt_keys))
        print('=' * 100)
# Example #4
# 0
def str_imp(o, uuid=None, other_stmt_keys=None):
    """Return a short one-line description of a database row object.

    The format is chosen by the name of ``o``'s class:

    * ``None`` gives '~';
    * ``TextRef``, ``TextContent`` and ``Reading`` rows give a summary of
      their ids and a few identifying attributes;
    * ``RawStatements`` rows give the statement rebuilt from the row's JSON
      plus id, truncated uuid, type, truncated indra version and hash. The
      string is prefixed with '+' when the statement's shallow hash is in
      ``other_stmt_keys`` and with '*' when its uuid equals ``uuid``;
    * anything else gives ``repr(o)``, so the result is always a string.
    """
    if o is None:
        return '~'
    cname = o.__class__.__name__
    if cname == 'TextRef':
        return ('<TextRef: trid: %s, pmid: %s, pmcid: %s>'
                % (o.id, o.pmid, o.pmcid))
    if cname == 'TextContent':
        return ('<TextContent: tcid: %s, trid: %s, src: %s>'
                % (o.id, o.text_ref_id, o.source))
    if cname == 'Reading':
        return ('<Reading: rid: %s, tcid: %s, reader: %s, rv: %s>'
                % (o.id, o.text_content_id, o.reader, o.reader_version))
    if cname == 'RawStatements':
        s = Statement._from_json(json.loads(o.json.decode()))
        s_str = ('<RawStmt: %s sid: %s, uuid: %s, type: %s, iv: %s, hash: %s>'
                 % (str(s), o.id, o.uuid[:8] + '...', o.type,
                    o.indra_version[:14] + '...', o.mk_hash))
        if other_stmt_keys and shash(s) in other_stmt_keys:
            s_str = '+' + s_str
        if s.uuid == uuid:
            s_str = '*' + s_str
        return s_str
    # Unrecognized object: fall back to its repr rather than implicitly
    # returning None from a function whose result is printed as text.
    return repr(o)
def _check_against_opa_stmts(db, raw_stmts, pa_stmts):
    """Assert that ``pa_stmts`` match old-fashioned preassembly results.

    ``raw_stmts`` are run through ``_do_old_fashioned_preassembly`` and the
    outcome is compared with ``pa_stmts`` by shallow hash. Statements found
    on only one side are elaborated on via ``elaborate_on_hash_diffs`` using
    ``db``. For hashes present on both sides, the evidence, ``supports`` and
    ``supported_by`` lists are compared as multisets.

    Raises ``AssertionError`` carrying a summary report when any comparison
    finds a difference.
    """
    def _compare_list_elements(label, list_func, comp_func, **stmts):
        # Multiset comparison of the two statements' keyed list elements;
        # returns None on a match, else a dict describing the leftovers.
        (name_a, stmt_a), (name_b, stmt_b) = stmts.items()
        left_a = list(map(comp_func, list_func(stmt_a)))
        left_b = []
        for item in list_func(stmt_b):
            key = comp_func(item)
            try:
                left_a.remove(key)
            except ValueError:
                left_b.append(key)
        if not (left_a or left_b):
            return None
        print("Found mismatched %s for hash %s:\n\t%s=%s\n\t%s=%s" %
              (label, shash(stmt_a), name_a, left_a, name_b, left_b))
        return {
            'diffs': {name_a: left_a, name_b: left_b},
            'stmts': {name_a: stmt_a, name_b: stmt_b},
        }

    def _make_test(label, list_func, comp_func):
        # One comparison spec plus the list that collects its mismatches.
        return {
            'funcs': {'list': list_func, 'comp': comp_func},
            'label': label,
            'results': [],
        }

    opa_stmts = _do_old_fashioned_preassembly(raw_stmts)

    # Index both statement collections by shallow hash.
    old_stmt_dict = {shash(s): s for s in opa_stmts}
    new_stmt_dict = {shash(s): s for s in pa_stmts}
    new_hash_set = set(new_stmt_dict.keys())
    old_hash_set = set(old_stmt_dict.keys())

    # Statements that only one of the two approaches produced.
    only_new = [new_stmt_dict[h] for h in new_hash_set - old_hash_set]
    only_old = [old_stmt_dict[h] for h in old_hash_set - new_hash_set]
    hash_diffs = {'extra_new': only_new, 'extra_old': only_old}
    if only_new:
        elaborate_on_hash_diffs(db, 'new', only_new, old_stmt_dict.keys())
    if only_old:
        elaborate_on_hash_diffs(db, 'old', only_old, new_stmt_dict.keys())
    print(hash_diffs)

    tests = [
        _make_test(
            'evidence text',
            lambda stmt: list(stmt.evidence),
            lambda ev: '%s-%s-%s' % (ev.source_api, ev.pmid, ev.text),
        ),
        _make_test(
            'supports matches keys',
            lambda stmt: list(stmt.supports),
            shash,
        ),
        _make_test(
            'supported-by matches keys',
            lambda stmt: list(stmt.supported_by),
            shash,
        ),
    ]

    # Run every comparison on each statement both sides agree exists.
    comp_hashes = new_hash_set & old_hash_set
    for mk_hash in comp_hashes:
        new_stmt = new_stmt_dict[mk_hash]
        old_stmt = old_stmt_dict[mk_hash]
        for test_dict in tests:
            funcs = test_dict['funcs']
            res = _compare_list_elements(test_dict['label'], funcs['list'],
                                         funcs['comp'], new_stmt=new_stmt,
                                         old_stmt=old_stmt)
            if res is None:
                continue
            test_dict['results'].append(res)

    def all_tests_passed():
        # The hash-set comparison counts as one test alongside the others.
        outcomes = [not any(hash_diffs.values())]
        outcomes.extend(not td['results'] for td in tests)
        print("%d/%d tests passed." % (sum(outcomes), len(outcomes)))
        return all(outcomes)

    def write_report(num_comps):
        pieces = [
            "Some tests failed:\n",
            'Found %d/%d extra old stmts and %d/%d extra new stmts.\n' %
            (len(only_old), len(old_hash_set),
             len(only_new), len(new_hash_set)),
        ]
        pieces.extend('Found %d/%d mismatches in %s.\n' %
                      (len(td['results']), num_comps, td['label'])
                      for td in tests)
        return ''.join(pieces)

    # Now evaluate the results for exceptions
    assert all_tests_passed(), write_report(len(comp_hashes))