def test_subclass_before_meta_init():
    """Test if mention (candidate) subclass can be created before Meta init."""
    # Declaring a subclass must not require an initialized Meta.
    pre_init_cls = mention_subclass("Part")
    logger.info(f"Create a mention subclass '{pre_init_cls.__tablename__}'")

    # Bring up the database connection afterwards...
    Meta.init("postgresql://localhost:5432/" + DB).Session()

    # ...and confirm that declaring another subclass still works.
    post_init_cls = mention_subclass("Temp")
    logger.info(f"Create a mention subclass '{post_init_cls.__tablename__}'")
def test_subclass_before_meta_init():
    """Test if it is possible to create a mention (candidate) subclass even
    before Meta is initialized.
    """
    # Subclass creation happens first, with no Meta connection at all.
    first_cls = mention_subclass("Part")
    logger.info(f"Create a mention subclass '{first_cls.__tablename__}'")

    # Initialize Meta, then verify a second subclass can still be declared.
    Meta.init(CONN_STRING).Session()
    second_cls = mention_subclass("Temp")
    logger.info(f"Create a mention subclass '{second_cls.__tablename__}'")
def test_subclass_before_meta_init(caplog):
    """Test if it is possible to create a mention (candidate) subclass even
    before Meta is initialized.
    """
    caplog.set_level(logging.INFO)

    db_url = "postgresql://localhost:5432/" + DB

    # A subclass declared before Meta.init must still be constructible.
    before_cls = mention_subclass("Part")
    logger.info(f"Create a mention subclass '{before_cls.__tablename__}'")

    # Initialize Meta and declare another subclass afterwards.
    Meta.init(db_url).Session()
    after_cls = mention_subclass("Temp")
    logger.info(f"Create a mention subclass '{after_cls.__tablename__}'")
def test_subclass_before_meta_init():
    """Test if mention (candidate) subclass can be created before Meta init.

    Creates a throwaway postgres database, runs the subclass checks against
    it, and guarantees the database is dropped again afterwards.
    """
    # Test if mention (candidate) subclass can be created
    Part = mention_subclass("Part")
    logger.info(f"Create a mention subclass '{Part.__tablename__}'")

    # Setup a database
    con = psycopg2.connect(
        host=os.environ["POSTGRES_HOST"],
        port=os.environ["POSTGRES_PORT"],
        user=os.environ["PGUSER"],
        password=os.environ["PGPASSWORD"],
    )
    # CREATE/DROP DATABASE cannot run inside a transaction block.
    con.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
    cursor = con.cursor()
    cursor.execute(f'create database "{DB}";')

    session = None
    try:
        session = Meta.init(CONN_STRING).Session()

        # Test if another mention subclass can be created
        Temp = mention_subclass("Temp")
        logger.info(f"Create a mention subclass '{Temp.__tablename__}'")
    finally:
        # Teardown must run even if the body above raises; otherwise the
        # test database and the open psycopg2 connection leak between runs.
        if session is not None:
            session.close()
        if Meta.engine is not None:
            Meta.engine.dispose()
            Meta.engine = None

        cursor.execute(f'drop database "{DB}";')
        cursor.close()
        con.close()
def test_multimodal_cand(caplog):
    """Test multimodal candidate generation"""
    caplog.set_level(logging.INFO)
    PARALLEL = 4
    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/pure_html/radiology.html"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 35
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction: one mention subclass per document modality
    # (document, section, table, figure, cell, paragraph, caption, sentence).
    ms_doc = mention_subclass("m_doc")
    ms_sec = mention_subclass("m_sec")
    ms_tab = mention_subclass("m_tab")
    ms_fig = mention_subclass("m_fig")
    ms_cell = mention_subclass("m_cell")
    ms_para = mention_subclass("m_para")
    ms_cap = mention_subclass("m_cap")
    ms_sent = mention_subclass("m_sent")

    # Matching mention-space generators for each modality.
    m_doc = MentionDocuments()
    m_sec = MentionSections()
    m_tab = MentionTables()
    m_fig = MentionFigures()
    m_cell = MentionCells()
    m_para = MentionParagraphs()
    m_cap = MentionCaptions()
    m_sent = MentionSentences()

    # NOTE: ms and m are ordered identically so subclass i pairs with space i.
    ms = [ms_doc, ms_cap, ms_sec, ms_tab, ms_fig, ms_para, ms_sent, ms_cell]
    m = [m_doc, m_cap, m_sec, m_tab, m_fig, m_para, m_sent, m_cell]
    matchers = [DoNothingMatcher()] * 8

    mention_extractor = MentionExtractor(session, ms, m, matchers, parallelism=PARALLEL)
    mention_extractor.apply(docs)

    # Expected mention counts per modality for the radiology document.
    assert session.query(ms_doc).count() == 1
    assert session.query(ms_cap).count() == 2
    assert session.query(ms_sec).count() == 5
    assert session.query(ms_tab).count() == 2
    assert session.query(ms_fig).count() == 2
    assert session.query(ms_para).count() == 30
    assert session.query(ms_sent).count() == 35
    assert session.query(ms_cell).count() == 21
def test_meta_connection_strings():
    """Simple sanity checks for validating postgres connection strings."""
    # Missing "://" separator must be rejected.
    with pytest.raises(ValueError):
        Meta.init("postgresql" + DB).Session()

    # Non-postgres scheme must be rejected.
    with pytest.raises(ValueError):
        Meta.init("sqlite://somethingsilly" + DB).Session()

    # A connection string without a database name must be rejected.
    with pytest.raises(ValueError):
        Meta.init("postgresql://somethingsilly:5432/").Session()

    # A valid string connects and records the database name on Meta.
    Meta.init("postgresql://localhost:5432/" + DB).Session()
    assert Meta.DBNAME == DB
    # Re-initializing against a different database switches Meta.DBNAME.
    Meta.init("postgresql://localhost:5432/" + "cand_test").Session()
    assert Meta.DBNAME == "cand_test"
def test_spacy_integration(caplog): """Run a simple e2e parse using spaCy as our parser. The point of this test is to actually use the DB just as would be done in a notebook by a user. """ # caplog.set_level(logging.INFO) logger = logging.getLogger(__name__) PARALLEL = 2 # Travis only gives 2 cores session = Meta.init('postgres://localhost:5432/' + ATTRIBUTE).Session() docs_path = 'tests/data/html_simple/' pdf_path = 'tests/data/pdf_simple/' max_docs = 2 doc_preprocessor = HTMLPreprocessor(docs_path, max_docs=max_docs) corpus_parser = OmniParser( structural=True, lingual=True, visual=False, pdf_path=pdf_path) corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL) docs = session.query(Document).order_by(Document.name).all() for doc in docs: logger.info("Doc: {}".format(doc.name)) for phrase in doc.phrases: logger.info(" Phrase: {}".format(phrase.text)) assert session.query(Document).count() == 2 assert session.query(Phrase).count() == 81
def parse(html_location, database, parallelism=1):
    """
    Wrapper function for calling Fonduer parser.

    :param html_location: HTML files generated by ``parse_preprocess.py``.
    :param database: db connection string.
    :param parallelism: number of worker processes to use; defaults to 1,
        preserving the previous single-process behavior.
    """
    session = Meta.init(database).Session()
    doc_preprocessor = HTMLDocPreprocessor(html_location)
    corpus_parser = Parser(
        session, structural=True, lingual=True, parallelism=parallelism
    )
    corpus_parser.apply(doc_preprocessor)
def load_context(self, context: PythonModelContext) -> None:
    """Load the saved Fonduer pipeline components from ``self.model_path``.

    Reads the flavor configuration from the MLmodel file, initializes the
    Fonduer session, and restores the parser, extractors, and either the
    discriminative or the generative model depending on ``model_type``.

    :param context: MLflow model context (unused directly; config is read
        via ``_get_flavor_configuration``).
    :raises RuntimeError: if the connection string is missing from MLmodel.
    """
    # Configure logging for Fonduer
    init_logging(log_dir="logs")
    logger.info("loading context")

    pyfunc_conf = _get_flavor_configuration(
        model_path=self.model_path, flavor_name=pyfunc.FLAVOR_NAME
    )
    conn_string = pyfunc_conf.get(CONN_STRING, None)
    if conn_string is None:
        raise RuntimeError("conn_string is missing from MLmodel file.")
    self.parallel = pyfunc_conf.get(PARALLEL, 1)
    session = Meta.init(conn_string).Session()

    logger.info("Getting parser")
    self.corpus_parser = self._get_parser(session)
    logger.info("Getting mention extractor")
    self.mention_extractor = self._get_mention_extractor(session)
    logger.info("Getting candidate extractor")
    self.candidate_extractor = self._get_candidate_extractor(session)
    candidate_classes = self.candidate_extractor.candidate_classes

    self.model_type = pyfunc_conf.get(MODEL_TYPE, "discriminative")
    if self.model_type == "discriminative":
        self.featurizer = Featurizer(session, candidate_classes)
        with open(os.path.join(self.model_path, "feature_keys.pkl"), "rb") as f:
            key_names = pickle.load(f)
        # Drop-then-upsert resets the key table to exactly the saved keys.
        self.featurizer.drop_keys(key_names)
        self.featurizer.upsert_keys(key_names)

        disc_model = LogisticRegression()

        # Workaround to https://github.com/HazyResearch/fonduer/issues/208
        checkpoint = torch.load(os.path.join(self.model_path, "best_model.pt"))
        disc_model.settings = checkpoint["config"]
        disc_model.cardinality = checkpoint["cardinality"]
        disc_model._build_model()

        disc_model.load(model_file="best_model.pt", save_dir=self.model_path)
        self.disc_model = disc_model
    else:
        self.labeler = Labeler(session, candidate_classes)
        with open(os.path.join(self.model_path, "labeler_keys.pkl"), "rb") as f:
            key_names = pickle.load(f)
        self.labeler.drop_keys(key_names)
        self.labeler.upsert_keys(key_names)

        # One saved LabelModel per candidate class, keyed by class name.
        self.gen_models = [
            LabelModel.load(os.path.join(self.model_path, _.__name__ + ".pkl"))
            for _ in candidate_classes
        ]
def test_parse_error_doc_skipping():
    """Test skipping of faulty htmls."""
    session = Meta.init(CONN_STRING).Session()
    parser = Parser(session)

    # An HTML document with a missing table tag that cannot be parsed.
    broken_html = "tests/data/html_faulty/ext_diseases_missing_table_tag.html"
    parser.apply(HTMLDocPreprocessor(broken_html))

    # apply() remembers every document it was handed...
    assert parser.last_docs == {"ext_diseases_missing_table_tag"}
    # ...but only successfully parsed documents are reported back.
    assert parser.get_last_documents() == []
def test_parse_structure(caplog):
    """Unit test of OmniParserUDF.parse_structure().

    This only tests the structural parse of the document.
    """
    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init('postgres://localhost:5432/' + ATTRIBUTE).Session()

    max_docs = 1
    docs_path = 'tests/data/html_simple/md.html'
    pdf_path = 'tests/data/pdf_simple/md.pdf'

    # Preprocessor for the Docs
    preprocessor = HTMLPreprocessor(docs_path, max_docs=max_docs)

    # Grab one document, text tuple from the preprocessor
    doc, text = next(preprocessor.generate())
    logger.info("    Text: {}".format(text))

    # Create an OmniParserUDF (positional arguments labeled inline)
    omni_udf = OmniParserUDF(
        True,  # structural
        ["style"],  # blacklist
        ["span", "br"],  # flatten
        '',  # flatten delim
        True,  # lingual
        True,  # strip
        [(u'[\u2010\u2011\u2012\u2013\u2014\u2212\uf02d]', '-')],  # replace
        True,  # tabular
        True,  # visual
        pdf_path,  # pdf path
        Spacy())  # lingual parser

    # Grab the phrases parsed by the OmniParser
    phrases = list(omni_udf.parse_structure(doc, text))

    logger.warning("Doc: {}".format(doc))
    for phrase in phrases:
        logger.warning("    Phrase: {}".format(phrase.text))

    header = phrases[0]
    # Test structural attributes
    assert header.xpath == '/html/body/h1'
    assert header.html_tag == 'h1'
    assert header.html_attrs == ['id=sample-markdown']

    # Test the unicode parse of delta
    assert (phrases[-1].text == "δ13Corg")

    # phrases expected in the "md" document.
    assert len(phrases) == 45
def parse(html_location, database, parallelism=1):
    """
    Wrapper function for calling Fonduer parser

    :param html_location: HTML files generated by ``parse_preprocess.py``
    :param database: DB connection string
    :param parallelism: Number of cores to use
    """
    fonduer_session = Meta.init(database).Session()
    preprocessor = HTMLDocPreprocessor(html_location)
    Parser(
        fonduer_session,
        structural=True,
        lingual=True,
        parallelism=parallelism,
    ).apply(preprocessor)
def test_parse_style(caplog):
    """Test style tag parsing."""
    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    # SpaCy on mac has issue on parallel parseing
    if os.name == "posix":
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores
    max_docs = 1
    docs_path = "tests/data/html_extended/ext_diseases.html"
    pdf_path = "tests/data/pdf_extended/ext_diseases.pdf"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    # Create an Parser and parse the md document
    omni = Parser(structural=True, lingual=True, visual=True, pdf_path=pdf_path)
    omni.apply(preprocessor, parallelism=PARALLEL)

    # Grab the document
    doc = session.query(Document).order_by(Document.name).all()[0]

    # Grab the sentences parsed by the Parser
    sentences = list(session.query(Sentence).order_by(Sentence.position).all())

    logger.warning("Doc: {}".format(doc))
    for i, sentence in enumerate(sentences):
        logger.warning("    Sentence[{}]: {}".format(i, sentence.html_attrs))

    # sentences for testing: expected html_attrs (including inherited style
    # attributes) at specific sentence positions.
    sub_sentences = [
        {
            "index": 6,
            "attr": [
                "class=col-header",
                "hobbies=work:hard;play:harder",
                "type=phenotype",
                "style=background: #f1f1f1; color: aquamarine; font-size: 18px;",
            ],
        },
        {"index": 9, "attr": ["class=row-header", "style=background: #f1f1f1;"]},
        {"index": 11, "attr": ["class=cell", "style=text-align: center;"]},
    ]

    # Assertions
    assert all(sentences[p["index"]].html_attrs == p["attr"] for p in sub_sentences)
def test_parse_document_md(caplog):
    """Unit test of OmniParser on a single document.

    This tests both the structural and visual parse of the document. This also
    serves as a test of single-threaded parsing.
    """
    logger = logging.getLogger(__name__)
    session = Meta.init('postgres://localhost:5432/' + ATTRIBUTE).Session()

    PARALLEL = 1
    max_docs = 2
    docs_path = 'tests/data/html_simple/'
    pdf_path = 'tests/data/pdf_simple/'

    # Preprocessor for the Docs
    preprocessor = HTMLPreprocessor(docs_path, max_docs=max_docs)

    # Create an OmniParser and parse the md document
    omni = OmniParser(structural=True, lingual=True, visual=True, pdf_path=pdf_path)
    omni.apply(preprocessor, parallelism=PARALLEL)

    # Grab the md document (second in name order)
    doc = session.query(Document).order_by(Document.name).all()[1]

    logger.info("Doc: {}".format(doc))
    for phrase in doc.phrases:
        logger.info("    Phrase: {}".format(phrase.text))

    header = doc.phrases[0]
    # Test structural attributes
    assert header.xpath == '/html/body/h1'
    assert header.html_tag == 'h1'
    assert header.html_attrs == ['id=sample-markdown']

    # Test visual attributes (bounding boxes per word, from the PDF)
    assert header.page == [1, 1]
    assert header.top == [35, 35]
    assert header.bottom == [61, 61]
    assert header.right == [111, 231]
    assert header.left == [35, 117]

    # Test lingual attributes
    assert header.ner_tags == ['O', 'O']
    assert header.dep_labels == ['compound', 'ROOT']

    # 45 phrases expected in the "md" document.
    assert len(doc.phrases) == 45
def test_visualizer(caplog):
    """Unit test of visualizer using the md document."""
    # Docstring moved above the import: previously the string followed the
    # import statement, making it a no-op expression instead of a docstring.
    from fonduer.utils.visualizer import Visualizer  # noqa
    caplog.set_level(logging.INFO)
    session = Meta.init("postgresql://localhost:5432/" + ATTRIBUTE).Session()

    PARALLEL = 1
    max_docs = 1
    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/md.pdf"

    # Preprocessor for the Docs
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    # Create an Parser and parse the md document
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    # Grab the md document
    doc = session.query(Document).order_by(Document.name).all()[0]
    assert doc.name == "md"

    # Mention extraction: single-word organization mentions
    organization_ngrams = MentionNgrams(n_max=1)
    Org = mention_subclass("Org")
    organization_matcher = OrganizationMatcher()
    mention_extractor = MentionExtractor(
        session, [Org], [organization_ngrams], [organization_matcher]
    )
    mention_extractor.apply([doc], parallelism=PARALLEL)

    # Candidate extraction over the Org mentions
    Organization = candidate_subclass("Organization", [Org])
    candidate_extractor = CandidateExtractor(session, [Organization])
    candidate_extractor.apply([doc], split=0, parallelism=PARALLEL)
    cands = session.query(Organization).filter(Organization.split == 0).all()

    # Test visualizer — it takes the PDF directory, not the single file.
    pdf_path = "tests/data/pdf_simple"
    vis = Visualizer(pdf_path)
    vis.display_candidates([cands[0]])
def test_parse_document_diseases(caplog):
    """Unit test of OmniParser on a single document.

    This tests both the structural and visual parse of the document.
    """
    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init('postgres://localhost:5432/' + ATTRIBUTE).Session()

    PARALLEL = 2
    max_docs = 2
    docs_path = 'tests/data/html_simple/'
    pdf_path = 'tests/data/pdf_simple/'

    # Preprocessor for the Docs
    preprocessor = HTMLPreprocessor(docs_path, max_docs=max_docs)

    # Create an OmniParser and parse the diseases document
    omni = OmniParser(structural=True, lingual=True, visual=True, pdf_path=pdf_path)
    omni.apply(preprocessor, parallelism=PARALLEL)

    # Grab the diseases document (first in name order)
    doc = session.query(Document).order_by(Document.name).all()[0]

    logger.info("Doc: {}".format(doc))
    for phrase in doc.phrases:
        logger.info("    Phrase: {}".format(phrase.text))

    # Pick a known table-cell phrase by its sorted position
    phrase = sorted(doc.phrases)[11]
    logger.info("    {}".format(phrase))

    # Test structural attributes
    assert phrase.xpath == '/html/body/table[1]/tbody/tr[3]/td[1]/p'
    assert phrase.html_tag == 'p'
    assert phrase.html_attrs == ['class=s6', 'style=padding-top: 1pt']

    # Test visual attributes (per-word coordinates from the PDF)
    assert phrase.page == [1, 1, 1]
    assert phrase.top == [342, 296, 356]
    assert phrase.left == [318, 369, 318]

    # Test lingual attributes
    assert phrase.ner_tags == ['O', 'O', 'GPE']
    assert phrase.dep_labels == ['ROOT', 'prep', 'pobj']

    # 36 phrases expected in the "diseases" document.
    assert len(doc.phrases) == 36
def test_simple_tokenizer(caplog):
    """Unit test of Parser on a single document with lingual features off."""
    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    # SpaCy on mac has issue on parallel parseing
    if os.name == "posix":
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores
    max_docs = 2
    docs_path = "tests/data/html_simple/"
    pdf_path = "tests/data/pdf_simple/"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    # lingual=False: the simple tokenizer leaves all lingual fields empty.
    parser = Parser(structural=True, lingual=False, visual=True, pdf_path=pdf_path)
    parser.apply(preprocessor, parallelism=PARALLEL)

    doc = session.query(Document).order_by(Document.name).all()[1]

    logger.info("Doc: {}".format(doc))
    for i, sentence in enumerate(doc.sentences):
        logger.info("    Sentence[{}]: {}".format(i, sentence.text))

    header = sorted(doc.sentences, key=lambda x: x.position)[0]
    # Test structural attributes
    assert header.xpath == "/html/body/h1"
    assert header.html_tag == "h1"
    assert header.html_attrs == ["id=sample-markdown"]

    # Test lingual attributes — all placeholders because lingual=False.
    assert header.ner_tags == ["", ""]
    assert header.dep_labels == ["", ""]
    assert header.dep_parents == [0, 0]
    assert header.lemmas == ["", ""]
    assert header.pos_tags == ["", ""]

    assert len(doc.sentences) == 44
def test_meta_connection_strings(database_session):
    """Simple sanity checks for validating postgres connection strings."""
    # Malformed strings fail validation; an unreachable host fails to connect.
    for expected_exc, url in [
        (ValueError, "postgresql" + DB),
        (ValueError, "sqlite://somethingsilly" + DB),
        (OperationalError, "postgresql://somethingsilly:5432/"),
    ]:
        with pytest.raises(expected_exc):
            Meta.init(url).Session()

    # A valid string connects; clean up the session and engine afterwards.
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()
    bind = session.get_bind()
    session.close()
    bind.dispose()

    assert Meta.DBNAME == DB
def test_ngrams(caplog):
    """Test ngram limits in mention extraction"""
    caplog.set_level(logging.INFO)
    PARALLEL = 4

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/pure_html/lincoln_short.html"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 503
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction with ngrams up to length 3
    Person = mention_subclass("Person")
    person_ngrams = MentionNgrams(n_max=3)
    person_matcher = PersonMatcher()

    mention_extractor = MentionExtractor(
        session, [Person], [person_ngrams], [person_matcher]
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)
    assert session.query(Person).count() == 118
    mentions = session.query(Person).all()
    # With n_max=3 there are unigram mentions but nothing longer than 3 words.
    assert len([x for x in mentions if x.context.get_num_words() == 1]) == 49
    assert len([x for x in mentions if x.context.get_num_words() > 3]) == 0

    # Test for unigram exclusion (n_min=2 removes all single-word mentions)
    person_ngrams = MentionNgrams(n_min=2, n_max=3)
    mention_extractor = MentionExtractor(
        session, [Person], [person_ngrams], [person_matcher]
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)
    assert session.query(Person).count() == 69
    mentions = session.query(Person).all()
    assert len([x for x in mentions if x.context.get_num_words() == 1]) == 0
    assert len([x for x in mentions if x.context.get_num_words() > 3]) == 0
def test_preprocessor_parse_file_called_once(mocker):
    """Test if DocPreprocessor._parse_file is called only once during parser.apply."""
    session = Meta.init(CONN_STRING).Session()
    parser = Parser(session)

    # Nothing parsed yet, so the parser has no record of prior documents.
    assert len(parser.get_last_documents()) == 0

    # Wrap the preprocessor's file parsing with a spy before applying.
    doc_limit = 1
    preprocessor = HTMLDocPreprocessor("tests/data/html/", max_docs=doc_limit)
    parse_file_spy = mocker.spy(preprocessor, "_parse_file")

    parser.apply(preprocessor)

    # The parser now reports exactly the documents it processed...
    assert len(parser.get_last_documents()) == doc_limit
    # ...and _parse_file ran a single time for the single document (#434).
    parse_file_spy.assert_called_once()
def test_row_col_ngram_extraction(caplog):
    """Test whether row/column ngrams list is empty, if mention is not in a table."""
    caplog.set_level(logging.INFO)
    PARALLEL = 1
    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()
    docs_path = "tests/data/pure_html/lincoln_short.html"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    place_ngrams = MentionNgramsTemp(n_max=4)
    Place = mention_subclass("Place")

    def get_row_and_column_ngrams(mention):
        # Matcher that also asserts on every mention it inspects:
        # non-tabular mentions must yield exactly [None] for both row and
        # column ngrams; tabular mentions must never contain None.
        row_ngrams = list(get_row_ngrams(mention))
        col_ngrams = list(get_col_ngrams(mention))
        if not mention.sentence.is_tabular():
            assert len(row_ngrams) == 1 and row_ngrams[0] is None
            assert len(col_ngrams) == 1 and col_ngrams[0] is None
        else:
            assert not any(x is None for x in row_ngrams)
            assert not any(x is None for x in col_ngrams)
        if "birth_place" in row_ngrams:
            return True
        else:
            return False

    birthplace_matcher = LambdaFunctionMatcher(func=get_row_and_column_ngrams)
    mention_extractor = MentionExtractor(
        session, [Place], [place_ngrams], [birthplace_matcher]
    )
    # Applying the extractor drives the assertions inside the matcher.
    mention_extractor.apply(docs, parallelism=PARALLEL)
def test_simple_tokenizer(caplog):
    """Unit test of OmniParser on a single document with lingual features off.
    """
    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init('postgres://localhost:5432/' + ATTRIBUTE).Session()

    PARALLEL = 2
    max_docs = 2
    docs_path = 'tests/data/html_simple/'
    pdf_path = 'tests/data/pdf_simple/'

    # Preprocessor for the Docs
    preprocessor = HTMLPreprocessor(docs_path, max_docs=max_docs)

    # lingual=False: all lingual fields remain empty placeholders.
    omni = OmniParser(
        structural=True, lingual=False, visual=True, pdf_path=pdf_path)
    omni.apply(preprocessor, parallelism=PARALLEL)

    doc = session.query(Document).order_by(Document.name).all()[1]

    logger.info("Doc: {}".format(doc))
    for i, phrase in enumerate(doc.phrases):
        logger.info("    Phrase[{}]: {}".format(i, phrase.text))

    header = doc.phrases[0]
    # Test structural attributes
    assert header.xpath == '/html/body/h1'
    assert header.html_tag == 'h1'
    assert header.html_attrs == ['id=sample-markdown']

    # Test lingual attributes — empty because lingual parsing is disabled.
    assert header.ner_tags == ['', '']
    assert header.dep_labels == ['', '']
    assert header.dep_parents == [0, 0]
    assert header.lemmas == ['', '']
    assert header.pos_tags == ['', '']

    assert len(doc.phrases) == 44
def test_warning_on_incorrect_filename(caplog):
    """Test that a warning is issued on invalid pdf."""
    caplog.set_level(logging.INFO)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    PARALLEL = 1
    docs_path = "tests/data/html_simple/md_para.html"
    # pdf_path deliberately points at the HTML file (not a PDF) so the
    # visual parse cannot find a valid PDF and must warn.
    pdf_path = "tests/data/html_simple/md_para.html"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path)

    # Create an Parser and parse the md document
    parser = Parser(
        structural=True, tabular=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    with pytest.warns(RuntimeWarning):
        parser.apply(preprocessor, parallelism=PARALLEL)
    # The document is still parsed despite the missing PDF.
    assert session.query(Document).count() == 1
def test_parse_style(caplog):
    """Test style tag parsing."""
    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init('postgres://localhost:5432/' + ATTRIBUTE).Session()

    max_docs = 1
    docs_path = 'tests/data/html_extended/ext_diseases.html'
    pdf_path = 'tests/data/pdf_extended/ext_diseases.pdf'

    # Preprocessor for the Docs
    preprocessor = HTMLPreprocessor(docs_path, max_docs=max_docs)

    # Grab the document, text tuple from the preprocessor
    doc, text = next(preprocessor.generate())
    logger.info("    Text: {}".format(text))

    # Create an OmniParserUDF (positional arguments labeled inline)
    omni_udf = OmniParserUDF(
        True,  # structural
        [],  # blacklist, empty so that style is not blacklisted
        ["span", "br"],  # flatten
        '',  # flatten delim
        True,  # lingual
        True,  # strip
        [],  # replace
        True,  # tabular
        True,  # visual
        pdf_path,  # pdf path
        Spacy())  # lingual parser

    # Grab the phrases parsed by the OmniParser
    phrases = list(omni_udf.parse_structure(doc, text))

    logger.warning("Doc: {}".format(doc))
    for phrase in phrases:
        logger.warning("    Phrase: {}".format(phrase.html_attrs))

    # Phrases for testing: expected html_attrs (including style attributes)
    # at specific phrase positions.
    sub_phrases = [
        {
            'index': 7,
            'attr': [
                'class=col-header',
                'hobbies=work:hard;play:harder',
                'type=phenotype',
                'style=background: #f1f1f1; color: aquamarine; font-size: 18px;'
            ]
        },
        {
            'index': 10,
            'attr': ['class=row-header', 'style=background: #f1f1f1;']
        },
        {
            'index': 12,
            'attr': ['class=cell', 'style=text-align: center;']
        }
    ]

    # Assertions
    assert(all(phrases[p['index']].html_attrs == p['attr'] for p in sub_phrases))
def test_mention_longest_match(caplog):
    """Test longest match filtering in mention extraction."""
    caplog.set_level(logging.INFO)
    # SpaCy on mac has issue on parallel parsing
    PARALLEL = 1

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/pure_html/lincoln_short.html"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    name_ngrams = MentionNgramsPart(n_max=3)
    place_ngrams = MentionNgramsTemp(n_max=4)

    Name = mention_subclass("Name")
    Place = mention_subclass("Place")

    def is_birthplace_table_row(mention):
        # Accept only mentions in a table row containing "birth_place".
        if not mention.sentence.is_tabular():
            return False
        ngrams = get_row_ngrams(mention, lower=True)
        if "birth_place" in ngrams:
            return True
        else:
            return False

    # With longest_match_only=False, sub-spans of longer matches are kept.
    birthplace_matcher = LambdaFunctionMatcher(
        func=is_birthplace_table_row, longest_match_only=False
    )
    mention_extractor = MentionExtractor(
        session,
        [Name, Place],
        [name_ngrams, place_ngrams],
        [PersonMatcher(), birthplace_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)
    mentions = session.query(Place).all()
    mention_spans = [x.context.get_span() for x in mentions]
    assert "Sinking Spring Farm" in mention_spans
    assert "Farm" in mention_spans
    assert len(mention_spans) == 23

    # With longest_match_only=True, contained sub-spans are filtered out.
    birthplace_matcher = LambdaFunctionMatcher(
        func=is_birthplace_table_row, longest_match_only=True
    )
    mention_extractor = MentionExtractor(
        session,
        [Name, Place],
        [name_ngrams, place_ngrams],
        [PersonMatcher(), birthplace_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)
    mentions = session.query(Place).all()
    mention_spans = [x.context.get_span() for x in mentions]
    assert "Sinking Spring Farm" in mention_spans
    assert "Farm" not in mention_spans
    assert len(mention_spans) == 4
def test_cand_gen(caplog):
    """Test extracting candidates from mentions from documents."""
    caplog.set_level(logging.INFO)

    if platform == "darwin":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        logger.info("Using two cores.")
        PARALLEL = 2  # Travis only gives 2 cores

    def do_nothing_matcher(fig):
        # Accept every figure mention.
        return True

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    # Mention classes, spaces, and matchers must all have the same arity.
    with pytest.raises(ValueError):
        mention_extractor = MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )

    with pytest.raises(ValueError):
        mention_extractor = MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, temp_matcher, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    # Expected mention counts for the single hardware datasheet document.
    assert session.query(Part).count() == 70
    assert session.query(Volt).count() == 33
    assert session.query(Temp).count() == 23
    assert session.query(Fig).count() == 31
    part = session.query(Part).order_by(Part.id).all()[0]
    volt = session.query(Volt).order_by(Volt.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Volt: {volt.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    # Throttler arity must match the number of candidate classes.
    with pytest.raises(ValueError):
        candidate_extractor = CandidateExtractor(
            session,
            [PartTemp, PartVolt],
            throttlers=[
                temp_throttler,
                volt_throttler,
                volt_throttler,
            ],  # Fail, mismatched arity
        )

    with pytest.raises(ValueError):
        candidate_extractor = CandidateExtractor(
            session,
            [PartTemp],  # Fail, mismatched arity
            throttlers=[temp_throttler, volt_throttler],
        )

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt]
    )  # Pass, no throttler
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1610
    assert session.query(PartVolt).count() == 2310
    assert session.query(Candidate).count() == 3920
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(PartVolt).count() == 0

    # Test with None in throttlers in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, None]
    )
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    # Only PartTemp is throttled, so PartVolt keeps its unthrottled count.
    assert session.query(PartTemp).count() == 1432
    assert session.query(PartVolt).count() == 2310
    assert session.query(Candidate).count() == 3742
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0

    # Both throttlers active.
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    assert session.query(PartTemp).count() == 1432
    assert session.query(PartVolt).count() == 1993
    assert session.query(Candidate).count() == 3425
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].volts) == 33
    assert len(docs[0].temps) == 23

    # Test that deletion of a Candidate does not delete the Mention
    session.query(PartTemp).delete(synchronize_session="fetch")
    assert session.query(PartTemp).count() == 0
    assert session.query(Temp).count() == 23
    assert session.query(Part).count() == 70

    # Test deletion of Candidate if Mention is deleted
    assert session.query(PartVolt).count() == 1993
    assert session.query(Volt).count() == 33
    session.query(Volt).delete(synchronize_session="fetch")
    assert session.query(Volt).count() == 0
    assert session.query(PartVolt).count() == 0
def test_cand_gen_cascading_delete(caplog):
    """Test cascading the deletion of candidates."""
    caplog.set_level(logging.INFO)

    # macOS CI workers run single-core; everything else uses two.
    if platform == "darwin":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        logger.info("Using two cores.")
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parse the corpus with structural, lingual, and visual information.
    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    parser.apply(preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Extract Part/Temp mentions from the parsed documents.
    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )
    mention_extractor.clear_all()
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Mention).count() == 93
    assert session.query(Part).count() == 70
    assert session.query(Temp).count() == 23
    part = session.query(Part).order_by(Part.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Extract PartTemp candidates, throttled on temperature.
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1432
    assert session.query(Candidate).count() == 1432
    first_doc = docs[0]
    assert first_doc.name == "112823"
    assert len(first_doc.parts) == 70
    assert len(first_doc.temps) == 23

    # A delete issued on the parent class should cascade to the child table.
    victim = session.query(Candidate).first()
    session.query(Candidate).filter_by(id=victim.id).delete(
        synchronize_session="fetch"
    )
    assert session.query(Candidate).count() == 1431
    assert session.query(PartTemp).count() == 1431

    # Clearing Mentions should also delete Candidates.
    mention_extractor.clear()
    assert session.query(Mention).count() == 0
    assert session.query(Part).count() == 0
    assert session.query(Temp).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(Candidate).count() == 0
def test_cand_gen_cascading_delete():
    """Test cascading the deletion of candidates."""
    # GitHub Actions runners provide 2 cores:
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2
    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parse the corpus with structural, lingual, and visual information.
    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    parser.apply(preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Extract Part/Temp mentions from the parsed documents.
    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )
    mention_extractor.clear_all()
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Mention).count() == 93
    assert session.query(Part).count() == 70
    assert session.query(Temp).count() == 23
    part = session.query(Part).order_by(Part.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Extract PartTemp candidates, throttled on temperature.
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1432
    assert session.query(Candidate).count() == 1432
    first_doc = docs[0]
    assert first_doc.name == "112823"
    assert len(first_doc.parts) == 70
    assert len(first_doc.temps) == 23

    # A delete issued on the parent class should cascade to the child table.
    parent_victim = session.query(Candidate).first()
    session.query(Candidate).filter_by(id=parent_victim.id).delete(
        synchronize_session="fetch"
    )
    assert session.query(Candidate).count() == 1431
    assert session.query(PartTemp).count() == 1431

    # Deleting a Candidate must leave its Mentions untouched.
    child_victim = session.query(PartTemp).first()
    session.query(PartTemp).filter_by(id=child_victim.id).delete(
        synchronize_session="fetch"
    )
    assert session.query(PartTemp).count() == 1430
    assert session.query(Temp).count() == 23
    assert session.query(Part).count() == 70

    # Clearing Mentions should also delete Candidates.
    mention_extractor.clear()
    assert session.query(Mention).count() == 0
    assert session.query(Part).count() == 0
    assert session.query(Temp).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(Candidate).count() == 0
def test_feature_extraction():
    """Test extracting candidates from mentions from documents."""
    PARALLEL = 1
    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parse the corpus with structural, lingual, and visual information.
    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    parser.apply(preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Extract unigram Part/Temp mentions.
    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    part_ngrams = MentionNgrams(n_max=1)
    temp_ngrams = MentionNgrams(n_max=1)
    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert docs[0].name == "112823"
    assert session.query(Part).count() == 58
    assert session.query(Temp).count() == 16
    part = session.query(Part).order_by(Part.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Extract PartTemp candidates (no throttler).
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    candidate_extractor = CandidateExtractor(session, [PartTemp])
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    n_cands = session.query(PartTemp).count()

    def run_featurizer(extractors=None):
        """Featurize split 0, return the FeatureKey count, then clear."""
        if extractors is None:
            featurizer = Featurizer(session, [PartTemp])
        else:
            featurizer = Featurizer(
                session, [PartTemp], feature_extractors=extractors
            )
        featurizer.apply(split=0, train=True, parallelism=PARALLEL)
        n_keys = session.query(FeatureKey).count()
        featurizer.clear(train=True)
        return n_keys

    # Featurization based on the default feature library.
    n_default_feats = run_featurizer()

    # Example customized feature extractor: one feature per candidate id.
    def feat_ext(candidates):
        candidates = candidates if isinstance(candidates, list) else [candidates]
        for candidate in candidates:
            yield candidate.id, f"cand_id_{candidate.id}", 1

    # Default feature library plus the extra extractor.
    n_default_w_customized_features = run_featurizer(
        FeatureExtractor(customize_feature_funcs=[feat_ext])
    )

    # Each individual feature family on its own.
    n_textual_features = run_featurizer(FeatureExtractor(features=["textual"]))
    n_tabular_features = run_featurizer(FeatureExtractor(features=["tabular"]))
    n_structural_features = run_featurizer(FeatureExtractor(features=["structural"]))
    n_visual_features = run_featurizer(FeatureExtractor(features=["visual"]))

    # The default library is exactly the union of the four feature families.
    assert (
        n_default_feats
        == n_textual_features
        + n_tabular_features
        + n_structural_features
        + n_visual_features
    )

    # The customized run adds exactly one extra key per candidate.
    assert n_default_w_customized_features == n_default_feats + n_cands
def test_incremental(caplog):
    """Run an end-to-end test on incremental additions."""
    caplog.set_level(logging.INFO)

    PARALLEL = 1
    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    # Start with a single document (dtc114w); a second one is added later.
    docs_path = "tests/data/html/dtc114w.html"
    pdf_path = "tests/data/pdf/dtc114w.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    docs = corpus_parser.get_documents()
    # NOTE(review): both calls use get_documents(); presumably the second was
    # meant to be get_last_documents() — confirm against the parser API.
    last_docs = corpus_parser.get_documents()

    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 70
    assert session.query(Candidate).count() == 70

    # Grab candidate lists (one list per candidate class).
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 70

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 70
    assert session.query(FeatureKey).count() == 512

    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (70, 512)
    assert len(featurizer.get_keys()) == 512

    # Test Dropping FeatureKey
    featurizer.drop_keys(["CORE_e1_LENGTH_1"])
    # NOTE(review): the key count is unchanged after drop_keys — confirm the
    # drop is expected to be a no-op for this key/candidate-class combination.
    assert session.query(FeatureKey).count() == 512

    # Labeling functions for storage-temperature candidates.
    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    labeler.apply(split=0, lfs=[stg_temp_lfs], train=True, parallelism=PARALLEL)
    assert session.query(Label).count() == 70

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (70, 5)
    assert len(labeler.get_keys()) == 5

    # Incrementally parse a second document (112823) without clearing.
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser.apply(doc_preprocessor, pdf_path=pdf_path, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    new_docs = corpus_parser.get_last_documents()

    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # Get mentions from just the new docs
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)

    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Just run candidate extraction and assign to split 0
    candidate_extractor.apply(new_docs, split=0, parallelism=PARALLEL, clear=False)

    # Grab candidate lists (now includes candidates from both documents).
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1502

    # Update features
    featurizer.update(new_docs, parallelism=PARALLEL)
    assert session.query(Feature).count() == 1502
    assert session.query(FeatureKey).count() == 2573
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (1502, 2573)
    assert len(featurizer.get_keys()) == 2573

    # Update Labels (the new doc activates the sixth labeling function's key).
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 1502
    assert session.query(LabelKey).count() == 6
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (1502, 6)

    # Test clear
    featurizer.clear(train=True)
    assert session.query(FeatureKey).count() == 0