def _check_preassembly_with_database(num_stmts, batch_size):
    """Run database preassembly on a loaded test corpus and validate it.

    Loads a test database via ``get_pa_loaded_db(num_stmts)`` and builds the
    preassembled corpus with a ``PreassemblyManager`` using ``batch_size``.
    It then checks the PA statements, the raw-to-unique evidence links and
    the support links. Finally it compares the result against old-fashioned
    (in-memory) preassembly of the same input.

    Parameters
    ----------
    num_stmts : int
        Passed to ``get_pa_loaded_db`` to size the test corpus.
    batch_size : int
        Batch size handed to ``pm.PreassemblyManager``.

    Raises
    ------
    AssertionError
        If any consistency check fails.
    """
    db = get_pa_loaded_db(num_stmts)

    # Now test the set of preassembled (pa) statements from the database
    # against what we get from old-fashioned preassembly (opa).
    opa_inp_stmts = _get_opa_input_stmts(db)

    # Get the set of raw statements.
    raw_stmt_list = db.select_all(db.RawStatements)
    all_raw_ids = {raw_stmt.id for raw_stmt in raw_stmt_list}
    assert len(raw_stmt_list)

    # Run the preassembly initialization.
    start = datetime.now()
    pa_manager = pm.PreassemblyManager(batch_size=batch_size, print_logs=True)
    pa_manager.create_corpus(db)
    end = datetime.now()
    print("Duration:", end - start)

    # Make sure the number of pa statements is within reasonable bounds.
    pa_stmt_list = db.select_all(db.PAStatements)
    assert 0 < len(pa_stmt_list) < len(raw_stmt_list)

    # Check the evidence links.
    raw_unique_link_list = db.select_all(db.RawUniqueLinks)
    assert len(raw_unique_link_list)
    all_link_ids = {ru.raw_stmt_id for ru in raw_unique_link_list}
    all_link_mk_hashes = {ru.pa_stmt_mk_hash for ru in raw_unique_link_list}
    # Every linked raw id must refer to an existing raw statement. Test the
    # set for emptiness rather than comparing its length to 0 with `is`,
    # which is an identity check and a SyntaxWarning on Python 3.8+.
    dangling_ids = all_link_ids - all_raw_ids
    assert not dangling_ids, \
        "Links reference unknown raw statement ids: %s" % dangling_ids
    assert all(pa_stmt.mk_hash in all_link_mk_hashes
               for pa_stmt in pa_stmt_list)

    # Check the support links.
    sup_links = db.select_all([
        db.PASupportLinks.supporting_mk_hash,
        db.PASupportLinks.supported_mk_hash
    ])
    assert sup_links
    assert not any(link[0] == link[1] for link in sup_links), \
        "Found self-support in the database."

    # Try to get all the preassembled statements from the table.
    pa_stmts = db_client.get_statements([], preassembled=True, db=db,
                                        with_support=True)
    assert len(pa_stmts) == len(pa_stmt_list), \
        (len(pa_stmts), len(pa_stmt_list))

    # Check each constructed statement individually. A dict keyed by hash
    # would collapse duplicates and could hide a self-supporting statement.
    for pa_stmt in pa_stmts:
        neighbor_hashes = {shash(other) for other
                           in pa_stmt.supported_by + pa_stmt.supports}
        assert shash(pa_stmt) not in neighbor_hashes, \
            "Found self-support in constructed pa statement objects."

    _check_against_opa_stmts(db, opa_inp_stmts, pa_stmts)
def _compare_list_elements(label, list_func, comp_func, **stmts): (stmt_1_name, stmt_1), (stmt_2_name, stmt_2) = list(stmts.items()) vals_1 = [comp_func(elem) for elem in list_func(stmt_1)] vals_2 = [] for element in list_func(stmt_2): val = comp_func(element) if val in vals_1: vals_1.remove(val) else: vals_2.append(val) if len(vals_1) or len(vals_2): print("Found mismatched %s for hash %s:\n\t%s=%s\n\t%s=%s" % (label, shash(stmt_1), stmt_1_name, vals_1, stmt_2_name, vals_2)) return { 'diffs': { stmt_1_name: vals_1, stmt_2_name: vals_2 }, 'stmts': { stmt_1_name: stmt_1, stmt_2_name: stmt_2 } } return None
def elaborate_on_hash_diffs(db, lbl, stmt_list, other_stmt_keys):
    """Print diagnostic detail for statements found on only one side.

    For each statement, print its uuid and hashes, its PA row looked up by
    shallow hash, and its raw row looked up by uuid. For reading-derived raw
    statements, also print the reading, text content and text ref chain.
    Then print every text content, reading and raw statement under that text
    ref.

    Parameters
    ----------
    db :
        Database handle exposing ``select_one``/``select_all`` and table
        attributes.
    lbl : str
        Label for the set of statements (e.g. ``'new'`` or ``'old'``).
    stmt_list : iterable
        Statements to elaborate on.
    other_stmt_keys : collection
        Hashes from the other side. ``str_imp`` uses it to mark raw
        statements whose hash also appears there.
    """
    print("#" * 100)
    print("Elaboration on extra %s statements:" % lbl)
    print("#" * 100)
    for s in stmt_list:
        print(s)
        uuid = s.uuid
        mk_hash = shash(s)
        print('-' * 100)
        print('uuid: %s\nhash: %s\nshallow hash: %s'
              % (s.uuid, s.get_hash(shallow=False), mk_hash))
        print('-' * 100)
        db_pas = db.select_one(db.PAStatements,
                               db.PAStatements.mk_hash == mk_hash)
        print('\tPA statement:', db_pas.__dict__ if db_pas else '~')
        print('-' * 100)
        db_s = db.select_one(db.RawStatements,
                             db.RawStatements.uuid == s.uuid)
        print('\tRaw statement:', str_imp(db_s, uuid, other_stmt_keys))
        if db_s is None:
            continue
        print('-' * 100)
        if db_s.reading_id is None:
            print("Statement was from a database: %s" % db_s.db_info_id)
            continue

        # Walk reading -> text content -> text ref. If a link is missing,
        # move on to the next statement rather than crash on None.
        db_r = db.select_one(db.Reading, db.Reading.id == db_s.reading_id)
        print('\tReading:', str_imp(db_r))
        if db_r is None:
            continue
        tc = db.select_one(db.TextContent,
                           db.TextContent.id == db_r.text_content_id)
        print('\tText Content:', str_imp(tc))
        if tc is None:
            continue
        tr = db.select_one(db.TextRef, db.TextRef.id == tc.text_ref_id)
        print('\tText ref:', str_imp(tr))
        if tr is None:
            continue
        print('-' * 100)

        # Everything under the same text ref. Use distinct loop names so the
        # outer `s` and `tc` are not rebound.
        for sib_tc in db.select_all(db.TextContent,
                                    db.TextContent.text_ref_id == tr.id):
            print('\t', str_imp(sib_tc))
            for sib_r in db.select_all(
                    db.Reading, db.Reading.text_content_id == sib_tc.id):
                print('\t\t', str_imp(sib_r))
                for sib_raw in db.select_all(
                        db.RawStatements,
                        db.RawStatements.reading_id == sib_r.id):
                    print('\t\t\t',
                          str_imp(sib_raw, uuid, other_stmt_keys))
        print('=' * 100)
def str_imp(o, uuid=None, other_stmt_keys=None):
    """Return a compact one-line summary of a database row for debugging.

    Rows are dispatched on their class name: ``TextRef``, ``TextContent``,
    ``Reading`` and ``RawStatements`` get custom formats. Any other object
    falls back to ``str(o)``.

    Parameters
    ----------
    o :
        Row object to summarize, or ``None``.
    uuid : str, optional
        For ``RawStatements``, prefix the summary with ``'*'`` when the
        decoded statement's uuid equals this value.
    other_stmt_keys : collection, optional
        For ``RawStatements``, prefix the summary with ``'+'`` when the
        decoded statement's shallow hash is in this collection.

    Returns
    -------
    str
        ``'~'`` for ``None``, otherwise the summary string.
    """
    if o is None:
        return '~'
    cname = o.__class__.__name__
    if cname == 'TextRef':
        return ('<TextRef: trid: %s, pmid: %s, pmcid: %s>'
                % (o.id, o.pmid, o.pmcid))
    if cname == 'TextContent':
        return ('<TextContent: tcid: %s, trid: %s, src: %s>'
                % (o.id, o.text_ref_id, o.source))
    if cname == 'Reading':
        return ('<Reading: rid: %s, tcid: %s, reader: %s, rv: %s>'
                % (o.id, o.text_content_id, o.reader, o.reader_version))
    if cname == 'RawStatements':
        # `o.json` is decoded as bytes and parsed as JSON here.
        s = Statement._from_json(json.loads(o.json.decode()))
        s_str = ('<RawStmt: %s sid: %s, uuid: %s, type: %s, iv: %s, hash: %s>'
                 % (str(s), o.id, o.uuid[:8] + '...', o.type,
                    o.indra_version[:14] + '...', o.mk_hash))
        if other_stmt_keys and shash(s) in other_stmt_keys:
            s_str = '+' + s_str
        if s.uuid == uuid:
            s_str = '*' + s_str
        return s_str
    # Previously fell off the end and returned None, which printed as the
    # uninformative "None". Show the object itself instead.
    return str(o)
def _check_against_opa_stmts(db, raw_stmts, pa_stmts):
    """Assert that DB preassembly agrees with old-fashioned preassembly.

    Runs ``_do_old_fashioned_preassembly`` on ``raw_stmts`` and matches the
    result against ``pa_stmts`` by shallow hash. Statements present on only
    one side are printed in detail via ``elaborate_on_hash_diffs``. For
    statements present on both sides, evidence, ``supports`` and
    ``supported_by`` are compared as multisets.

    Parameters
    ----------
    db :
        Database handle, used only for diagnostic elaboration.
    raw_stmts : list
        Input statements for old-fashioned preassembly.
    pa_stmts : list
        Preassembled statements retrieved from the database.

    Raises
    ------
    AssertionError
        If any hash set differs or any per-statement comparison mismatches;
        the message summarizes the failure counts.
    """
    opa_stmts = _do_old_fashioned_preassembly(raw_stmts)

    old_stmt_dict = {shash(s): s for s in opa_stmts}
    new_stmt_dict = {shash(s): s for s in pa_stmts}

    new_hash_set = set(new_stmt_dict.keys())
    old_hash_set = set(old_stmt_dict.keys())
    hash_diffs = {
        'extra_new': [new_stmt_dict[h] for h in new_hash_set - old_hash_set],
        'extra_old': [old_stmt_dict[h] for h in old_hash_set - new_hash_set]
    }
    if hash_diffs['extra_new']:
        elaborate_on_hash_diffs(db, 'new', hash_diffs['extra_new'],
                                old_stmt_dict.keys())
    if hash_diffs['extra_old']:
        elaborate_on_hash_diffs(db, 'old', hash_diffs['extra_old'],
                                new_stmt_dict.keys())
    print(hash_diffs)

    # Each test pairs an element extractor with a value used for matching.
    tests = [{
        'funcs': {
            'list': lambda s: s.evidence[:],
            'comp': lambda ev: '%s-%s-%s' % (ev.source_api, ev.pmid, ev.text)
        },
        'label': 'evidence text',
        'results': []
    }, {
        'funcs': {
            'list': lambda s: s.supports[:],
            'comp': shash
        },
        'label': 'supports matches keys',
        'results': []
    }, {
        'funcs': {
            'list': lambda s: s.supported_by[:],
            'comp': shash
        },
        'label': 'supported-by matches keys',
        'results': []
    }]

    # Use the module-level `_compare_list_elements` helper; this function
    # used to carry an identical nested copy of it.
    comp_hashes = new_hash_set & old_hash_set
    for mk_hash in comp_hashes:
        for test_dict in tests:
            res = _compare_list_elements(test_dict['label'],
                                         test_dict['funcs']['list'],
                                         test_dict['funcs']['comp'],
                                         new_stmt=new_stmt_dict[mk_hash],
                                         old_stmt=old_stmt_dict[mk_hash])
            if res is not None:
                test_dict['results'].append(res)

    def all_tests_passed():
        # One entry for the hash-set comparison plus one per element test.
        test_results = [not any(hash_diffs.values())]
        for td in tests:
            test_results.append(len(td['results']) == 0)
        print("%d/%d tests passed." % (sum(test_results), len(test_results)))
        return all(test_results)

    def write_report(num_comps):
        ret_str = "Some tests failed:\n"
        ret_str += ('Found %d/%d extra old stmts and %d/%d extra new stmts.\n'
                    % (len(hash_diffs['extra_old']), len(old_hash_set),
                       len(hash_diffs['extra_new']), len(new_hash_set)))
        for td in tests:
            ret_str += ('Found %d/%d mismatches in %s.\n'
                        % (len(td['results']), num_comps, td['label']))
        return ret_str

    # Now evaluate the results for exceptions
    assert all_tests_passed(), write_report(len(comp_hashes))