def test_pdf_word_list_is_sorted():
    """Verify the ordering of the word list built from a shuffled HTML page.

    The fixture ``no_image_unsorted.html`` was derived from
    ``pdf_simple/no_image.pdf`` with its block and word elements deliberately
    reordered, so the expected order has to come from sorting rather than
    from the order of elements in the file.
    """
    html_file = "tests/data/html_simple/no_image_unsorted.html"
    pdf_dir = "tests/data/pdf_simple"
    visual_parser = PdfVisualParser(pdf_path=pdf_dir)

    with open(html_file) as handle:
        soup = BeautifulSoup(handle, "html.parser")
    first_page = soup.find_all("page")[0]

    pdf_word_list, _coordinate_map = visual_parser._coordinates_from_HTML(
        first_page, 1
    )

    # The two leading words must be the title words (ordering by block top).
    leading_words = [word for (_, word) in pdf_word_list[:2]]
    assert set(leading_words) == {"Sample", "HTML"}

    # The following five words must appear in reading order (ordering by top).
    following_words = [word for (_, word) in pdf_word_list[2:7]]
    assert following_words == ["This", "is", "an", "html", "that"]

    # Within the title line the words must be ordered left to right (#449).
    assert leading_words == ["Sample", "HTML"]
def test_parse_wo_tabular():
    """Parse ``md.html`` with ``tabular=False`` and check the resulting contexts."""
    html_file = "tests/data/html_simple/md.html"
    pdf_dir = "tests/data/pdf_simple/"

    # Load the raw (unparsed) document.
    preprocessor = HTMLDocPreprocessor(html_file)
    raw_doc = next(preprocessor._parse_file(html_file, "md"))

    # Parse it with every feature enabled except tabular extraction.
    parser_udf = get_parser_udf(
        structural=True,
        tabular=False,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_dir),
        language="en",
    )
    doc = parser_udf.apply(raw_doc)

    # The document must contain neither tables nor cells.
    assert len(doc.sections) == 1
    assert len(doc.paragraphs) == 44
    assert len(doc.figures) == 1
    assert len(doc.tables) == 0
    assert len(doc.cells) == 0
    assert len(doc.sentences) == 45

    # Every sentence belongs to a section and to a paragraph ...
    assert all(sentence.section is not None for sentence in doc.sentences)
    assert all(sentence.paragraph is not None for sentence in doc.sentences)
    # ... and no sentence is attached to a cell.
    assert all(sentence.cell is None for sentence in doc.sentences)
def test_visual_parser_not_affected_by_order_of_sentences():
    """Test that the visual_parser result does not depend on sentence order.

    The same document is parsed twice. Visual information is attached once to
    sentences sorted by position and once to shuffled sentences. After both
    results are sorted by position again, their ``left`` values must agree.
    """
    docs_path = "tests/data/html/2N6427.html"
    pdf_path = "tests/data/pdf/"

    # Note that the parser is initialized with `visual=False`; visual_parser
    # is used to attach "visual" information to sentences after parsing.
    preprocessor = HTMLDocPreprocessor(docs_path)
    parser_udf = get_parser_udf(
        structural=True, lingual=False, tabular=True, visual=False
    )
    visual_parser = PdfVisualParser(pdf_path=pdf_path)

    # First pass: feed the sentences sorted by sentence.position.
    doc = parser_udf.apply(next(iter(preprocessor)))
    doc.sentences = sorted(doc.sentences, key=attrgetter("position"))
    # Sort the output again in case visual_parser changes the order.
    sentences0 = sorted(
        visual_parser.parse(doc.name, doc.sentences), key=attrgetter("position")
    )

    # Second pass: feed the sentences in random order.
    doc = parser_udf.apply(next(iter(preprocessor)))
    random.shuffle(doc.sentences)
    sentences1 = sorted(
        visual_parser.parse(doc.name, doc.sentences), key=attrgetter("position")
    )

    # zip() silently truncates to the shorter input, so without this check the
    # comparisons below could pass vacuously on empty or mismatched results.
    assert len(sentences0) == len(sentences1) > 0

    # This should hold as both lists are sorted by position.
    assert all(
        sent0.position == sent1.position
        for sent0, sent1 in zip(sentences0, sentences1)
    )

    # This holds only if the visual_parser result is not affected by the
    # order in which the sentences were supplied.
    assert all(
        sent0.left == sent1.left for sent0, sent1 in zip(sentences0, sentences1)
    )
def test_parse_style():
    """Check that HTML attributes, including ``style``, are kept on sentences."""
    logger = logging.getLogger(__name__)

    html_file = "tests/data/html_extended/ext_diseases.html"
    pdf_dir = "tests/data/pdf_extended/"

    # Load the raw document.
    preprocessor = HTMLDocPreprocessor(html_file)
    raw_doc = next(preprocessor._parse_file(html_file, "ext_diseases"))

    # Parse it with structural, lingual and visual features.
    parser_udf = get_parser_udf(
        structural=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_dir),
    )
    doc = parser_udf.apply(raw_doc)

    sentences = doc.sentences

    # Dump the attributes of every sentence to ease debugging.
    logger.warning(f"Doc: {doc}")
    for idx, parsed in enumerate(sentences):
        logger.warning(f" Sentence[{idx}]: {parsed.html_attrs}")

    # Expected html_attrs keyed by the index of the sentence under test.
    expected_attrs = {
        6: [
            "class=col-header",
            "hobbies=work:hard;play:harder",
            "type=phenotype",
            "style=background: #f1f1f1; color: aquamarine; font-size: 18px;",
        ],
        9: ["class=row-header", "style=background: #f1f1f1;"],
        11: ["class=cell", "style=text-align: center;"],
    }

    assert all(
        sentences[idx].html_attrs == attrs for idx, attrs in expected_attrs.items()
    )
def parse(session: Session, docs_path: str, pdf_path: str) -> List[Document]:
    """Parse the documents under ``docs_path`` with the Parser UDF runner.

    :param session: Session handed to the ``Parser``.
    :param docs_path: Path given to ``HTMLDocPreprocessor``.
    :param pdf_path: Path given to ``PdfVisualParser``.
    :return: The documents reported by ``Parser.get_documents()``.
    """
    html_preprocessor = HTMLDocPreprocessor(docs_path)

    # Single-process parser with structural, lingual and visual parsing.
    runner = Parser(
        session,
        parallelism=1,
        structural=True,
        lingual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )

    # Clear before applying, as the original flow does.
    runner.clear()
    runner.apply(html_preprocessor)
    return runner.get_documents()
def parse_doc(docs_path: str, file_name: str, pdf_path: Optional[str] = None):
    """Parse a single document and return it.

    :param docs_path: Path of the HTML document.
    :param file_name: Name passed to the preprocessor for the document.
    :param pdf_path: Optional path for ``PdfVisualParser``; visual parsing is
        enabled only when this is truthy.
    :return: The document returned by the parser UDF.
    """
    max_docs = 1

    logger.info("Parsing...")
    html_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    raw_doc = next(html_preprocessor._parse_file(docs_path, file_name))

    # Visual parsing needs a PDF location; switch it off when none is given.
    use_visual = bool(pdf_path)
    visual_parser = PdfVisualParser(pdf_path) if use_visual else None

    parser_udf = get_parser_udf(
        structural=True,
        tabular=True,
        lingual=True,
        visual=use_visual,
        visual_parser=visual_parser,
        language="en",
    )
    return parser_udf.apply(raw_doc)
def test_simple_parser():
    """Parse ``md.html`` with ``lingual=False`` and a ``SimpleParser``."""
    logger = logging.getLogger(__name__)

    html_file = "tests/data/html_simple/md.html"
    pdf_dir = "tests/data/pdf_simple/"

    # Load the raw document and make sure it carries its name.
    preprocessor = HTMLDocPreprocessor(html_file)
    raw_doc = next(preprocessor._parse_file(html_file, "md"))
    assert raw_doc.name == "md"

    # Parse with lingual features off.
    parser_udf = get_parser_udf(
        structural=True,
        lingual=False,
        visual=True,
        visual_parser=PdfVisualParser(pdf_dir),
        lingual_parser=SimpleParser(delim="NoDelim"),
    )
    doc = parser_udf.apply(raw_doc)

    logger.info(f"Doc: {doc}")
    for idx, parsed in enumerate(doc.sentences):
        logger.info(f" Sentence[{idx}]: {parsed.text}")

    # The sentence with the lowest position is the <h1> header.
    header = sorted(doc.sentences, key=lambda s: s.position)[0]

    # Structural attributes
    assert header.xpath == "/html/body/h1"
    assert header.html_tag == "h1"
    assert header.html_attrs == ["id=sample-markdown"]

    # Lingual attributes hold placeholder values because lingual=False.
    blank_pair = ["", ""]
    assert header.ner_tags == blank_pair
    assert header.dep_labels == blank_pair
    assert header.dep_parents == [0, 0]
    assert header.lemmas == blank_pair
    assert header.pos_tags == blank_pair

    assert len(doc.sentences) == 44
def test_warning_on_incorrect_filename():
    """Test that a RuntimeWarning is issued when pdf_path is not a PDF directory.

    ``pdf_path`` deliberately points at the HTML fixture directory, so the
    visual parser is not given the PDF directory used by the other tests.
    """
    docs_path = "tests/data/html_simple/md_para.html"
    pdf_path = "tests/data/html_simple/"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path)
    doc = next(preprocessor._parse_file(docs_path, "md_para"))

    # Create a Parser and parse the md_para document
    parser_udf = get_parser_udf(
        structural=True,
        tabular=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )

    with pytest.warns(RuntimeWarning) as record:
        doc = parser_udf.apply(doc)

    # Inspect the recorded warnings themselves. The previous check compared
    # the recorder's type against a freshly built recorder, which is always
    # true and therefore verified nothing.
    assert any(issubclass(warning.category, RuntimeWarning) for warning in record)
def test_warning_on_missing_pdf():
    """Check that parsing emits exactly one "Visual parse failed" RuntimeWarning."""
    html_file = "tests/data/html_simple/table_span.html"
    pdf_dir = "tests/data/pdf_simple/"

    # Load the raw document through the preprocessor's iterator.
    preprocessor = HTMLDocPreprocessor(html_file)
    raw_doc = next(iter(preprocessor))

    # Parse with visual features on.
    parser_udf = get_parser_udf(
        structural=True,
        tabular=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_dir),
    )

    with pytest.warns(RuntimeWarning) as caught:
        parser_udf.apply(raw_doc)

    # Exactly one warning, and it must be the visual-parse failure.
    assert len(caught) == 1
    first_message = caught[0].message.args[0]
    assert "Visual parse failed" in first_message
def test_parser_no_image():
    """Parse ``no_image.html`` and check that no figure is extracted."""
    html_file = "tests/data/html_simple/no_image.html"
    pdf_dir = "tests/data/pdf_simple/"

    # Load the raw document and make sure it carries its name.
    preprocessor = HTMLDocPreprocessor(html_file)
    raw_doc = next(preprocessor._parse_file(html_file, "no_image"))
    assert raw_doc.name == "no_image"

    # Parse with structural and visual features; lingual is off.
    parser_udf = get_parser_udf(
        structural=True,
        lingual=False,
        visual=True,
        visual_parser=PdfVisualParser(pdf_dir),
    )
    parsed_doc = parser_udf.apply(raw_doc)

    # The document must not yield any figure.
    assert len(parsed_doc.figures) == 0
def mention_setup():
    """Parse ``md.html`` and return its 1-gram span mentions as a list."""
    html_file = "tests/data/html_simple/md.html"
    pdf_dir = "tests/data/pdf_simple/"

    # Load the raw document through the preprocessor's iterator.
    preprocessor = HTMLDocPreprocessor(html_file)
    raw_doc = next(iter(preprocessor))

    # Parse with all features enabled.
    parser_udf = get_parser_udf(
        structural=True,
        tabular=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_dir),
        language="en",
    )
    parsed_doc = parser_udf.apply(raw_doc)

    # Enumerate the 1-gram span mentions of the parsed document.
    unigram_space = MentionNgrams(n_min=1, n_max=1)
    return list(unigram_space.apply(parsed_doc))
def test_parse_md_details():
    """Parse ``md.html`` with all features on and check the contexts in detail."""
    logger = logging.getLogger(__name__)

    html_file = "tests/data/html_simple/md.html"
    pdf_dir = "tests/data/pdf_simple/"

    # Load the raw document.
    preprocessor = HTMLDocPreprocessor(html_file)
    doc = next(preprocessor._parse_file(html_file, "md"))

    assert doc.name == "md"

    # Before parsing, the document holds no contexts at all.
    assert len(doc.figures) == 0
    assert len(doc.tables) == 0
    assert len(doc.cells) == 0
    assert len(doc.sentences) == 0

    # Parse with all features enabled.
    parser_udf = get_parser_udf(
        structural=True,
        tabular=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_dir),
        language="en",
    )
    doc = parser_udf.apply(doc)

    # Figure
    assert len(doc.figures) == 1
    figure = doc.figures[0]
    assert figure.url == "http://placebear.com/200/200"
    assert figure.position == 0
    assert figure.section.position == 0
    assert figure.stable_id == "md::figure:0"

    # Table
    assert len(doc.tables) == 1
    table = doc.tables[0]
    assert table.position == 0
    assert table.section.position == 0
    assert table.document.name == "md"

    # Cells
    assert len(doc.cells) == 16
    cells = list(doc.cells)
    first_cell = cells[0]
    assert first_cell.row_start == 0
    assert first_cell.col_start == 0
    assert first_cell.position == 0
    assert first_cell.document.name == "md"
    assert first_cell.table.position == 0

    eleventh_cell = cells[10]
    assert eleventh_cell.row_start == 2
    assert eleventh_cell.col_start == 2
    assert eleventh_cell.position == 10
    assert eleventh_cell.document.name == "md"
    assert eleventh_cell.table.position == 0

    # Sentences
    assert len(doc.sentences) == 45
    table_sentence = sorted(doc.sentences, key=lambda s: s.position)[25]
    assert table_sentence.text == "Spicy"
    assert table_sentence.table.position == 0
    assert table_sentence.table.section.position == 0
    assert table_sentence.cell.row_start == 0
    assert table_sentence.cell.col_start == 2

    # The tail text must be processed after inner elements (#333).
    tail_texts = [s.text for s in doc.sentences[14:18]]
    assert tail_texts == ["italics and later", "bold", ".", "Even"]

    # abs_char_offsets must index into the concatenated sentence text (#332).
    full_text = "".join([s.text for s in doc.sentences])
    for current in doc.sentences:
        for offset, word in zip(current.abs_char_offsets, current.words):
            assert full_text[offset] == word[0]

    logger.info(f"Doc: {doc}")
    for idx, current in enumerate(doc.sentences):
        logger.info(f" Sentence[{idx}]: {current.text}")

    header = sorted(doc.sentences, key=lambda s: s.position)[0]

    # Structural attributes
    assert header.xpath == "/html/body/h1"
    assert header.html_tag == "h1"
    assert header.html_attrs == ["id=sample-markdown"]

    # Visual attributes
    assert header.page == [1, 1]
    assert header.top == [35, 35]
    assert header.bottom == [61, 61]
    assert header.right == [111, 231]
    assert header.left == [35, 117]

    # Lingual attributes: with lingual=True, some value other than "" should
    # be filled in.
    # NOTE(review): only the last element of doc.sentences is checked here.
    # That matches the original behavior, which reused a loop variable after
    # its loop ended; confirm whether every sentence (or the header) was meant.
    last_sentence = doc.sentences[-1]
    assert all(last_sentence.ner_tags)
    assert all(last_sentence.dep_labels)

    # Every NLP annotation list must line up with the sentence's words.
    for current in doc.sentences:
        word_count = len(current.words)
        assert word_count == len(current.lemmas)
        assert word_count == len(current.pos_tags)
        assert word_count == len(current.ner_tags)
        assert word_count == len(current.dep_parents)
        assert word_count == len(current.dep_labels)
def test_parse_document_diseases():
    """Parse ``diseases.html`` and check its structural and visual parse."""
    logger = logging.getLogger(__name__)

    html_file = "tests/data/html_simple/diseases.html"
    pdf_dir = "tests/data/pdf_simple/"

    # Load the raw document and make sure it carries its name.
    preprocessor = HTMLDocPreprocessor(html_file)
    raw_doc = next(preprocessor._parse_file(html_file, "diseases"))
    assert raw_doc.name == "diseases"

    # Parse with structural, lingual and visual features.
    parser_udf = get_parser_udf(
        structural=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_dir),
    )
    doc = parser_udf.apply(raw_doc)

    logger.info(f"Doc: {doc}")
    for parsed in doc.sentences:
        logger.info(f" Sentence: {parsed.text}")

    # Captions
    assert len(doc.captions) == 2
    caption = sorted(doc.sentences, key=lambda s: s.position)[20]
    assert caption.paragraph.caption.position == 0
    assert caption.paragraph.caption.table.position == 0
    assert caption.text == "Table 1: Infectious diseases and where to find them."
    assert caption.paragraph.position == 18

    # Figures
    assert len(doc.figures) == 0

    # Tables
    assert len(doc.tables) == 3
    first_table = doc.tables[0]
    assert first_table.position == 0
    assert first_table.document.name == "diseases"

    # Cells
    assert len(doc.cells) == 25

    sentence = sorted(doc.sentences, key=lambda s: s.position)[10]
    logger.info(f" {sentence}")

    # The sentence's table and cell
    assert sentence.table.position == 0
    assert sentence.cell.row_start == 2
    assert sentence.cell.col_start == 1
    assert sentence.cell.position == 4

    # Structural attributes
    assert sentence.xpath == "/html/body/table[1]/tbody/tr[3]/td[1]/p"
    assert sentence.html_tag == "p"
    assert sentence.html_attrs == ["class=s6", "style=padding-top: 1pt"]

    # Visual attributes
    assert sentence.page == [1, 1, 1]
    assert sentence.top == [342, 296, 356]
    assert sentence.left == [318, 369, 318]

    # Lingual attributes: with lingual=True, some value other than "" should
    # be filled in.
    assert all(sentence.ner_tags)
    assert all(sentence.dep_labels)

    assert len(doc.sentences) == 37
def test_incremental(database_session):
    """Run an end-to-end test on incremental additions.

    A first document is parsed, then mentions, candidates, features and
    labels are extracted. A second document is then added with
    ``clear=False`` and each stage is re-run or updated on the new document
    only, checking that existing rows are kept and not duplicated.
    """
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 1

    session = database_session

    docs_path = "tests/data/html/dtc114w.html"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    docs = corpus_parser.get_documents()
    # Compare against get_last_documents(): fetching get_documents() twice
    # would only compare a value with itself.
    last_docs = corpus_parser.get_last_documents()

    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Test if clear=True works
    mention_extractor.apply(docs, parallelism=PARALLEL, clear=True)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Candidate Extraction
    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 70
    assert session.query(Candidate).count() == 70

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 70

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == len(train_cands[0])
    num_feature_keys = session.query(FeatureKey).count()
    assert num_feature_keys == 514

    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (len(train_cands[0]), num_feature_keys)
    assert len(featurizer.get_keys()) == num_feature_keys

    # Test Dropping FeatureKey
    featurizer.drop_keys(["BASIC_e1_LENGTH_1"])
    assert session.query(FeatureKey).count() == num_feature_keys - 1

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    labeler.apply(split=0, lfs=[stg_temp_lfs], train=True, parallelism=PARALLEL)
    assert session.query(Label).count() == len(train_cands[0])

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (len(train_cands[0]), 5)
    assert len(labeler.get_keys()) == 5

    # Add a second document; corpus_parser keeps the visual parser that was
    # configured when it was constructed.
    docs_path = "tests/data/html/112823.html"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser.apply(doc_preprocessor, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    new_docs = corpus_parser.get_last_documents()

    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # Get mentions from just the new docs
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)
    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Test if existing mentions are skipped.
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)
    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Just run candidate extraction and assign to split 0
    candidate_extractor.apply(new_docs, split=0, parallelism=PARALLEL, clear=False)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1501

    # Test if existing candidates are skipped.
    candidate_extractor.apply(new_docs, split=0, parallelism=PARALLEL, clear=False)
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1501

    # Update features
    featurizer.update(new_docs, parallelism=PARALLEL)
    assert session.query(Feature).count() == len(train_cands[0])
    num_feature_keys = session.query(FeatureKey).count()
    assert num_feature_keys == 2608
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (len(train_cands[0]), num_feature_keys)
    assert len(featurizer.get_keys()) == num_feature_keys

    # Update LF_storage_row. Now it always returns ABSTAIN.
    @labeling_function(name="LF_storage_row")
    def LF_storage_row_updated(c):
        return ABSTAIN

    stg_temp_lfs = [
        LF_storage_row_updated,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    # Update Labels
    labeler.update(docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == len(train_cands[0])

    # Only 5 because LF_storage_row doesn't apply to any doc (always ABSTAIN)
    num_label_keys = session.query(LabelKey).count()
    assert num_label_keys == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (len(train_cands[0]), num_label_keys)

    # Test clear
    featurizer.clear(train=True)
    assert session.query(FeatureKey).count() == 0
def test_too_many_clients_error_should_not_happen(database_session):
    """Test that a "too many clients" error does not happen.

    Every pipeline stage is run with a high ``parallelism`` value. The test
    has no explicit assertions; it passes when no stage raises.
    """
    PARALLEL = 32
    # The f-prefix is required; without it the placeholder is logged literally.
    logger.info(f"Parallel: {PARALLEL}")

    def do_nothing_matcher(fig):
        return True

    max_docs = 1
    session = database_session

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session,
        structural=True,
        lingual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt]
    )  # Pass, no throttler

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    candidate_extractor.clear_all(split=0)

    # Test with None in throttlers in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, None]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
def test_parse_md_paragraphs():
    """Parse ``md_para.html`` and check figures, tables, cells and paragraphs."""
    html_file = "tests/data/html_simple/md_para.html"
    pdf_dir = "tests/data/pdf_simple/"

    # Load the raw document and make sure it carries its name.
    preprocessor = HTMLDocPreprocessor(html_file)
    raw_doc = next(preprocessor._parse_file(html_file, "md_para"))
    assert raw_doc.name == "md_para"

    # Parse with all features enabled.
    parser_udf = get_parser_udf(
        structural=True,
        tabular=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_dir),
    )
    doc = parser_udf.apply(raw_doc)

    # Figures
    assert len(doc.figures) == 6

    first_figure = doc.figures[0]
    assert first_figure.url == "http://placebear.com/200/200"
    assert first_figure.position == 0
    assert first_figure.section.position == 0
    assert len(first_figure.captions) == 0
    assert first_figure.stable_id == "md_para::figure:0"
    assert first_figure.cell.position == 13

    third_figure = doc.figures[2]
    kookaburra_url = (
        "http://html5doctor.com/wp-content/uploads/2010/03/kookaburra.jpg"
    )
    assert third_figure.url == kookaburra_url
    assert third_figure.position == 2
    assert len(third_figure.captions) == 1
    caption_sentences = third_figure.captions[0].paragraphs[0].sentences
    assert len(caption_sentences) == 3
    assert (
        third_figure.captions[0].paragraphs[0].sentences[0].text
        == "Australian Birds."
    )

    fifth_figure = doc.figures[4]
    assert len(fifth_figure.captions) == 0
    pelican_url = "http://html5doctor.com/wp-content/uploads/2010/03/pelican.jpg"
    assert fifth_figure.url == pelican_url

    # Table
    assert len(doc.tables) == 1
    assert doc.tables[0].position == 0
    assert doc.tables[0].section.position == 0

    # Cells
    assert len(doc.cells) == 16
    cells = list(doc.cells)

    first_cell = cells[0]
    assert first_cell.row_start == 0
    assert first_cell.col_start == 0
    assert first_cell.position == 0
    assert first_cell.table.position == 0

    eleventh_cell = cells[10]
    assert eleventh_cell.row_start == 2
    assert eleventh_cell.col_start == 2
    assert eleventh_cell.position == 10
    assert eleventh_cell.table.position == 0

    # Sentences
    assert len(doc.sentences) == 51
    by_position = sorted(doc.sentences, key=lambda s: s.position)
    sent1, sent2, sent3 = by_position[1], by_position[2], by_position[3]

    assert sent1.text == "This is some basic, sample markdown."
    assert sent2.text == (
        "Unlike the other markdown document, however, "
        "this document actually contains paragraphs of text."
    )

    # The three sentences share one paragraph in the first section.
    assert sent1.paragraph.position == 1
    assert sent1.section.position == 0
    assert sent2.paragraph.position == 1
    assert sent2.section.position == 0
    assert sent3.paragraph.position == 1
    assert sent3.section.position == 0

    # Paragraphs
    assert len(doc.paragraphs) == 46
    assert len(doc.paragraphs[1].sentences) == 3
    assert len(doc.paragraphs[2].sentences) == 1
def test_cand_gen_cascading_delete(database_session):
    """Check how deletions propagate between candidates and mentions."""
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 1
    session = database_session

    html_dir = "tests/data/html/"
    pdf_dir = "tests/data/pdf/"

    # ----- Parsing -----
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(html_dir, max_docs=max_docs)
    corpus_parser = Parser(
        session,
        structural=True,
        lingual=True,
        visual_parser=PdfVisualParser(pdf_dir),
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # ----- Mention extraction -----
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp],
        [part_ngrams, temp_ngrams],
        [part_matcher, temp_matcher],
    )
    mention_extractor.clear_all()
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Mention).count() == 93
    assert session.query(Part).count() == 70
    assert session.query(Temp).count() == 23

    first_part = session.query(Part).order_by(Part.id).all()[0]
    first_temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {first_part.context}")
    logger.info(f"Temp: {first_temp.context}")

    # ----- Candidate extraction -----
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1431
    assert session.query(Candidate).count() == 1431
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].temps) == 23

    # Deleting through the parent class must cascade to the child class.
    parent_row = session.query(Candidate).first()
    session.query(Candidate).filter_by(id=parent_row.id).delete(
        synchronize_session="fetch"
    )
    assert session.query(Candidate).count() == 1430
    assert session.query(PartTemp).count() == 1430

    # Deleting a candidate must leave its mentions in place.
    child_row = session.query(PartTemp).first()
    doomed = session.query(PartTemp).filter_by(id=child_row.id).first()
    session.delete(doomed)
    assert session.query(PartTemp).count() == 1429
    assert session.query(Temp).count() == 23
    assert session.query(Part).count() == 70

    # Clearing the mentions must also remove every candidate.
    mention_extractor.clear()
    assert session.query(Mention).count() == 0
    assert session.query(Part).count() == 0
    assert session.query(Temp).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(Candidate).count() == 0
def test_non_existent_pdf_path_should_fail():
    """Check that PdfVisualParser raises ValueError for the path "dummy_path"."""
    missing_dir = "dummy_path"
    with pytest.raises(ValueError):
        PdfVisualParser(pdf_path=missing_dir)
def _precision_recall_f1(tp_len, fp_len, fn_len):
    """Return ``(precision, recall, f1)`` computed from TP/FP/FN counts.

    A score is ``float("nan")`` whenever its denominator is not positive, so
    any ordering comparison made against it afterwards evaluates to ``False``.
    """
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")
    return prec, rec, f1


def _fit_train_marginals(label_matrix):
    """Fit a binary ``LabelModel`` on ``label_matrix`` and return its output.

    Returns:
        A ``(marginals, train_idxs)`` tuple where ``marginals`` is the result
        of ``predict_proba`` and ``train_idxs`` holds the row indices whose
        largest and smallest marginal differ by more than ``1e-6``.
    """
    label_model = LabelModel(cardinality=2)
    label_model.fit(L_train=label_matrix, n_epochs=500, seed=1234, log_freq=100)
    marginals = label_model.predict_proba(label_matrix)
    diffs = marginals.max(axis=1) - marginals.min(axis=1)
    return marginals, np.where(diffs > 1e-6)[0]


def _reset_emmental(config):
    """Call ``emmental.Meta.reset()``, re-initialize Emmental and apply ``config``."""
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)


def _train_attribute_model(tasks, dataloaders):
    """Create an ``EmmentalModel`` holding ``tasks`` and train it on ``dataloaders``.

    Returns:
        The model after ``EmmentalLearner.learn`` has been run on it.
    """
    model = EmmentalModel(name=f"{ATTRIBUTE}_task")
    for task in tasks:
        model.add_task(task)
    EmmentalLearner().learn(model, dataloaders)
    return model


def _entity_level_test_f1(
    model, test_dataloader, candidates, threshold, gold_file, test_docs, parts_by_doc
):
    """Predict on the test split and return the entity-level F1 score.

    Candidates whose predicted probability for the ``TRUE`` class exceeds
    ``threshold`` are passed to ``entity_level_f1`` together with the gold
    file; precision, recall and F1 are logged before F1 is returned.
    """
    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > threshold)
    true_pred = [candidates[idx] for idx in positive[0]]
    TP, FP, FN = entity_level_f1(
        true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc
    )
    prec, rec, f1 = _precision_recall_f1(len(TP), len(FP), len(FN))
    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")
    return f1


def test_e2e(database_session):
    """Run an end-to-end test on documents of the hardware domain.

    The stages are executed in order against the same database session:
    parsing, train/dev/test splitting, mention extraction, candidate
    extraction, featurization (including key drop/upsert), labeling (gold and
    labeling functions), and finally training and entity-level evaluation of
    LogisticRegression and LSTM models in STL and MTL mode.
    """
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 12

    fonduer.init_logging(
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = database_session

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    # First half (by document name) is train, next quarter dev, the rest test.
    for i, doc in enumerate(sorted(docs, key=lambda d: d.name)):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions(docs)) == 3

    # Candidate Extraction
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    # Use a dedicated loop variable so that ``docs`` keeps referring to the
    # full, sorted document list.
    for split, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=split, parallelism=PARALLEL)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)

    # Candidate lists should be deterministically sorted.
    assert (
        "112823::implicit_span_mention:11059:11065:part_expander:0"
        == train_cands[0][0][0].context.get_stable_id()
    )
    assert (
        "112823::implicit_span_mention:2752:2754:temp_expander:0"
        == train_cands[0][0][1].context.get_stable_id()
    )
    assert len(train_cands) == 2
    # Candidates can also be fetched for an explicit set of documents
    # (here: the test split).
    assert len(candidate_extractor.get_candidates(test_docs)) == 2

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    num_feature_keys = session.query(FeatureKey).count()
    assert num_feature_keys == 1278

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["BASIC_e0_CONTAINS_WORDS_[BC182]"])
    assert session.query(FeatureKey).count() == num_feature_keys - 1

    lemma_key = "DDL_e0_LEMMA_SEQ_[bc182]"

    def lemma_key_classes():
        """Return the candidate classes currently attached to ``lemma_key``."""
        return (
            session.query(FeatureKey)
            .filter(FeatureKey.name == lemma_key)
            .one()
            .candidate_classes
        )

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(lemma_key_classes()) == {"part_temp", "part_volt"}
    featurizer.drop_keys([lemma_key], candidate_classes=[PartVolt])
    assert lemma_key_classes() == ["part_temp"]
    assert session.query(FeatureKey).count() == num_feature_keys - 1

    # Inserting the removed key
    featurizer.upsert_keys([lemma_key], candidate_classes=[PartTemp, PartVolt])
    assert set(lemma_key_classes()) == {"part_temp", "part_volt"}
    assert session.query(FeatureKey).count() == num_feature_keys - 1

    # Removing the key again
    featurizer.drop_keys([lemma_key], candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys([lemma_key], candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == num_feature_keys - 2

    # Start from a clean slate before featurizing the three splits.
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    # The number of Features should equal the total number of train candidates
    num_features = session.query(Feature).count()
    assert num_features == len(train_cands[0]) + len(train_cands[1])
    num_feature_keys = session.query(FeatureKey).count()
    assert num_feature_keys == 4629
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (len(train_cands[0]), num_feature_keys)
    assert F_train[1].shape == (len(train_cands[1]), num_feature_keys)
    assert len(featurizer.get_keys()) == num_feature_keys

    featurizer.apply(split=1, parallelism=PARALLEL)
    # The number of Features should increase by the total number of dev candidates
    num_features += len(dev_cands[0]) + len(dev_cands[1])
    assert session.query(Feature).count() == num_features
    assert session.query(FeatureKey).count() == num_feature_keys
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (len(dev_cands[0]), num_feature_keys)
    assert F_dev[1].shape == (len(dev_cands[1]), num_feature_keys)

    featurizer.apply(split=2, parallelism=PARALLEL)
    # The number of Features should increase by the total number of test candidates
    num_features += len(test_cands[0]) + len(test_cands[1])
    assert session.query(Feature).count() == num_features
    assert session.query(FeatureKey).count() == num_feature_keys
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (len(test_cands[0]), num_feature_keys)
    assert F_test[1].shape == (len(test_cands[1]), num_feature_keys)

    gold_file = "tests/data/hardware_tutorial_gold.csv"

    labeler = Labeler(session, [PartTemp, PartVolt])

    # This should raise an error, since gold labels are not yet loaded.
    with pytest.raises(ValueError):
        labeler.get_gold_labels(train_cands, annotator="gold")

    labeler.apply(
        docs=last_docs,
        lfs=[[gold], [gold]],
        table=GoldLabel,
        train=True,
        parallelism=PARALLEL,
    )
    # All candidates should now be gold-labeled.
    assert session.query(GoldLabel).count() == session.query(Candidate).count()

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]

    # ``lfs`` given as a flat list (not one list per candidate class) is
    # expected to be rejected.
    with pytest.raises(ValueError):
        labeler.apply(split=0, lfs=stg_temp_lfs, train=True, parallelism=PARALLEL)

    labeler.apply(
        docs=train_docs,
        lfs=[stg_temp_lfs, ce_v_max_lfs],
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(Label).count() == len(train_cands[0]) + len(train_cands[1])
    num_label_keys = session.query(LabelKey).count()
    assert num_label_keys == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (len(train_cands[0]), num_label_keys)
    assert L_train[1].shape == (len(train_cands[1]), num_label_keys)
    assert len(labeler.get_keys()) == num_label_keys

    # Test Dropping LabelerKey
    labeler.drop_keys(["LF_storage_row"])
    assert len(labeler.get_keys()) == num_label_keys - 1

    # Test Upserting LabelerKey
    labeler.upsert_keys(["LF_storage_row"])
    assert "LF_storage_row" in [label.name for label in labeler.get_keys()]

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (len(train_cands[0]), 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (len(train_cands[0]), 1)

    train_marginals, train_idxs = _fit_train_marginals(L_train[0])

    # Collect word counter
    word_counter = collect_word_counter(train_cands)

    emmental.init(fonduer.Meta.log_path)

    # Training config
    config = {
        "meta_config": {"verbose": False},
        "model_config": {"model_path": None, "device": 0, "dataparallel": False},
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {"lr": 0.001, "l2": 0.0},
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {f"{ATTRIBUTE}/{ATTRIBUTE}/train/loss": "min"},
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }
    emmental.Meta.update_config(config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials.extend([f"~~[[{i}", f"{i}]]~~"])

    emb_layer = EmbeddingModule(
        word_counter=word_counter, word_dim=300, specials=specials
    )

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    # First model: LogisticRegression trained on the initial labeling functions.
    tasks = create_task(
        ATTRIBUTE, 2, F_train[0].shape[1], 2, emb_layer, model="LogisticRegression"
    )
    model = _train_attribute_model(tasks, [train_dataloader])

    test_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE, test_cands[0], F_test[0], emb_layer.word2id, 2
        ),
        split="test",
        batch_size=100,
        shuffle=False,
    )

    # NOTE: pickle.load must only be used on trusted files; this one is a
    # fixture read from the repository's tests/data directory.
    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    def score(trained_model, threshold):
        """Return the entity-level F1 of ``trained_model`` on the test split."""
        return _entity_level_test_f1(
            trained_model,
            test_dataloader,
            test_cands[0],
            threshold,
            gold_file,
            test_docs,
            parts_by_doc,
        )

    f1 = score(model, threshold=0.6)
    assert 0.3 < f1 < 0.7

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0, lfs=[stg_temp_lfs_2, ce_v_max_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == len(train_cands[0]) + len(train_cands[1])
    num_label_keys = session.query(LabelKey).count()
    assert num_label_keys == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (len(train_cands[0]), num_label_keys)

    # Refit the label model on the updated label matrix.
    train_marginals, train_idxs = _fit_train_marginals(L_train[0])

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    valid_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            np.argmax(train_marginals, axis=1),
            train_idxs,
        ),
        split="valid",
        batch_size=100,
        shuffle=False,
    )

    # Testing STL/MTL LogisticRegression (train + valid) and LSTM (train only),
    # in the same order as they were previously spelled out one by one.
    model_stages = [
        ("LogisticRegression", "STL", [train_dataloader, valid_dataloader]),
        ("LSTM", "STL", [train_dataloader]),
        ("LogisticRegression", "MTL", [train_dataloader, valid_dataloader]),
        ("LSTM", "MTL", [train_dataloader]),
    ]
    for model_name, mode, dataloaders in model_stages:
        _reset_emmental(config)
        tasks = create_task(
            ATTRIBUTE,
            2,
            F_train[0].shape[1],
            2,
            emb_layer,
            model=model_name,
            mode=mode,
        )
        model = _train_attribute_model(tasks, dataloaders)
        f1 = score(model, threshold=0.7)
        assert f1 > 0.7, f"{mode} {model_name}: f1={f1}"