def test_cand_gen_cascading_delete(database_session):
    """Test cascading the deletion of candidates.

    Parses one document, extracts Part/Temp mentions and PartTemp candidates,
    then checks that:

    * deleting a row through the ``Candidate`` parent removes the ``PartTemp``
      child row too,
    * deleting a ``PartTemp`` candidate leaves the Part/Temp mentions intact,
    * clearing the mentions removes every mention and every candidate.
    """
    session = database_session

    def n_rows(model):
        # Number of rows currently visible for ``model`` in this session.
        return session.query(model).count()

    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2
    max_docs = 1

    # Parsing
    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor("tests/data/html/", max_docs=max_docs)
    corpus_parser = Parser(
        session,
        structural=True,
        lingual=True,
        visual_parser=PdfVisualParser("tests/data/pdf/"),
    )
    corpus_parser.apply(preprocessor, parallelism=PARALLEL)
    assert n_rows(Document) == max_docs
    assert n_rows(Sentence) == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_space = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_space = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp],
        [part_space, temp_space],
        [part_matcher, temp_matcher],
    )
    mention_extractor.clear_all()
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert n_rows(Mention) == 93
    assert n_rows(Part) == 70
    assert n_rows(Temp) == 23
    first_part = session.query(Part).order_by(Part.id).all()[0]
    first_temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {first_part.context}")
    logger.info(f"Temp: {first_temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert n_rows(PartTemp) == 1431
    assert n_rows(Candidate) == 1431
    first_doc = docs[0]
    assert first_doc.name == "112823"
    assert len(first_doc.parts) == 70
    assert len(first_doc.temps) == 23

    # Delete from parent class should cascade to child
    doomed = session.query(Candidate).first()
    session.query(Candidate).filter_by(id=doomed.id).delete(
        synchronize_session="fetch"
    )
    assert n_rows(Candidate) == 1430
    assert n_rows(PartTemp) == 1430

    # Test that deletion of a Candidate does not delete the Mention
    doomed = session.query(PartTemp).first()
    victim = session.query(PartTemp).filter_by(id=doomed.id).first()
    session.delete(victim)
    assert n_rows(PartTemp) == 1429
    assert n_rows(Temp) == 23
    assert n_rows(Part) == 70

    # Clearing Mentions should also delete Candidates
    mention_extractor.clear()
    for model in (Mention, Part, Temp, PartTemp, Candidate):
        assert n_rows(model) == 0
def parse_dataset(
    session, dirname, first_time=False, max_docs=float("inf"), parallel=4
):
    """Parse the dataset into dev, test, and train.

    This expects that the data is located in data/dev/, data/test/, data/train/
    directories, and each of those contains html/ and pdf/. Also expects that the
    filenames of the HTML and PDF match.

    :param session: The database session
    :param dirname: Directory that contains the ``data/`` folder.
    :param first_time: If True, parse the HTML (and PDF) files of every split
        into the database. The dev split is parsed first with the parser's
        default ``clear`` behaviour; test and train are parsed with
        ``clear=False``. If False, reload the already-parsed documents and
        assign each one to a split when its name is among the names
        ``_files_in_dir`` returns for that split's ``pdf/`` directory.
    :param max_docs: The maximum number of documents to parse from each set.
        Defaults to parsing all documents. Only used when ``first_time`` is True.
    :param parallel: Parallelism passed to ``Parser.apply``.
    :rtype: (all_docs, train_docs, dev_docs, test_docs)
    """
    divisions = ["dev", "test", "train"]
    docs_by_division = {division: set() for division in divisions}

    if first_time:
        for division in divisions:
            logger.info(f"Parsing {division}...")
            html_path = os.path.join(dirname, f"data/{division}/html/")
            pdf_path = os.path.join(dirname, f"data/{division}/pdf/")
            doc_preprocessor = HTMLDocPreprocessor(html_path, max_docs=max_docs)
            corpus_parser = Parser(
                session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
            )
            # Do not want to clear the database when parsing test and train
            if division == "dev":
                corpus_parser.apply(doc_preprocessor, parallelism=parallel)
            else:
                corpus_parser.apply(
                    doc_preprocessor, parallelism=parallel, clear=False
                )
            docs_by_division[division] = set(corpus_parser.get_last_documents())
        all_docs = corpus_parser.get_documents()
    else:
        logger.info("Reloading pre-parsed dataset.")
        all_docs = Parser(session).get_documents()
        names_by_division = {
            division: _files_in_dir(os.path.join(dirname, f"data/{division}/pdf/"))
            for division in divisions
        }
        # A document may match more than one split; it is added to each match.
        for doc in all_docs:
            for division in divisions:
                if doc.name in names_by_division[division]:
                    docs_by_division[division].add(doc)

    return (
        all_docs,
        docs_by_division["train"],
        docs_by_division["dev"],
        docs_by_division["test"],
    )
def test_unary_relation_feature_extraction():
    """Test extracting unary candidates from mentions from documents.

    Also checks that the number of feature keys produced by the default
    feature library equals the sum of the counts produced by the textual,
    tabular, structural and visual libraries run one at a time.
    """
    PARALLEL = 1

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)

    Part = mention_subclass("Part")

    mention_extractor = MentionExtractor(session, [Part], [part_ngrams], [part_matcher])
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert docs[0].name == "112823"
    assert session.query(Part).count() == 58
    part = session.query(Part).order_by(Part.id).all()[0]
    logger.info(f"Part: {part.context}")

    # Candidate Extraction
    PartRel = candidate_subclass("PartRel", [Part])

    candidate_extractor = CandidateExtractor(session, [PartRel])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    def count_feature_keys(**featurizer_kwargs):
        """Featurize split 0, return the FeatureKey count, then clear the keys."""
        featurizer = Featurizer(session, [PartRel], **featurizer_kwargs)
        featurizer.apply(split=0, train=True, parallelism=PARALLEL)
        n_keys = session.query(FeatureKey).count()
        featurizer.clear(train=True)
        return n_keys

    # Featurization based on default feature library
    n_default_feats = count_feature_keys()

    # Featurization with a single feature library at a time
    n_feats_by_type = {
        feature_type: count_feature_keys(
            feature_extractors=FeatureExtractor(features=[feature_type])
        )
        for feature_type in ("textual", "tabular", "structural", "visual")
    }

    assert n_default_feats == sum(n_feats_by_type.values())
def test_too_many_clients_error_should_not_happen(database_session):
    """Too many clients error should not happen.

    Runs parsing, mention extraction and candidate extraction with 32-way
    parallelism on a single document. There are no assertions: the test
    passes as long as none of these steps raises.
    """
    PARALLEL = 32
    logger.info(f"Parallel: {PARALLEL}")

    def do_nothing_matcher(fig):
        """Accept every figure mention."""
        return True

    max_docs = 1
    session = database_session

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session,
        structural=True,
        lingual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt]
    )  # Pass, no throttler

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    candidate_extractor.clear_all(split=0)

    # Test with None in throttlers in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, None]
    )
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
def test_cand_gen_cascading_delete(caplog):
    """Test cascading the deletion of candidates.

    Deleting a row through the ``Candidate`` parent must also remove the
    ``PartTemp`` child row, and clearing mentions must remove all candidates.
    """
    caplog.set_level(logging.INFO)

    if platform == "darwin":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        logger.info("Using two cores.")
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 10
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 5548
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Mention).count() == 370
    assert session.query(Part).count() == 234
    assert session.query(Temp).count() == 136
    part = session.query(Part).order_by(Part.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 3879
    assert session.query(Candidate).count() == 3879
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].temps) == 24

    # Delete from parent class should cascade to child
    x = session.query(Candidate).first()
    session.query(Candidate).filter_by(id=x.id).delete(synchronize_session="fetch")
    assert session.query(PartTemp).count() == 3878
    assert session.query(Candidate).count() == 3878

    # Clearing Mentions should also delete Candidates
    mention_extractor.clear()
    assert session.query(Mention).count() == 0
    assert session.query(Part).count() == 0
    assert session.query(Temp).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(Candidate).count() == 0
def test_mention_longest_match(caplog):
    """Test longest match filtering in mention extraction.

    Extracts Place mentions twice from the same document: once with
    ``longest_match_only=False`` (the nested span "Farm" is kept) and once with
    ``longest_match_only=True`` (only the enclosing "Sinking Spring Farm" is
    kept).
    """
    caplog.set_level(logging.INFO)
    # SpaCy on mac has issue on parallel parsing
    PARALLEL = 1

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/pure_html/lincoln_short.html"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    name_ngrams = MentionNgramsPart(n_max=3)
    place_ngrams = MentionNgramsTemp(n_max=4)

    Name = mention_subclass("Name")
    Place = mention_subclass("Place")

    def is_birthplace_table_row(mention):
        """Return True if the mention is tabular and its row mentions birth_place."""
        if not mention.sentence.is_tabular():
            return False
        return "birth_place" in get_row_ngrams(mention, lower=True)

    def extract_place_spans(longest_match_only):
        """Run mention extraction and return the spans of all Place mentions."""
        birthplace_matcher = LambdaFunctionMatcher(
            func=is_birthplace_table_row, longest_match_only=longest_match_only
        )
        mention_extractor = MentionExtractor(
            session,
            [Name, Place],
            [name_ngrams, place_ngrams],
            [PersonMatcher(), birthplace_matcher],
        )
        mention_extractor.apply(docs, parallelism=PARALLEL)
        mentions = session.query(Place).all()
        return [x.context.get_span() for x in mentions]

    mention_spans = extract_place_spans(longest_match_only=False)
    assert "Sinking Spring Farm" in mention_spans
    assert "Farm" in mention_spans
    assert len(mention_spans) == 23

    mention_spans = extract_place_spans(longest_match_only=True)
    assert "Sinking Spring Farm" in mention_spans
    assert "Farm" not in mention_spans
    assert len(mention_spans) == 4
def test_parse_document_diseases(caplog):
    """Unit test of Parser on a single document.

    This tests both the structural and visual parse of the document.
    """
    import sys

    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    # SpaCy on mac has issue on parallel parsing, so only macOS runs serially.
    # ``os.name == "posix"`` would also match Linux and disable parallelism
    # on the Linux CI runners as well.
    if sys.platform == "darwin":
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 2
    docs_path = "tests/data/html_simple/"
    pdf_path = "tests/data/pdf_simple/"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    # Create an Parser and parse the diseases document
    parser = Parser(structural=True, lingual=True, visual=True, pdf_path=pdf_path)
    parser.apply(preprocessor, parallelism=PARALLEL)

    # Grab the diseases document
    doc = session.query(Document).order_by(Document.name).all()[0]
    assert doc.name == "diseases"

    logger.info("Doc: {}".format(doc))
    for sentence in doc.sentences:
        logger.info(" Sentence: {}".format(sentence.text))

    # Sentences in document order; indexed more than once below.
    sentences = sorted(doc.sentences, key=lambda x: x.position)

    # Check captions
    assert len(doc.captions) == 2
    caption = sentences[20]
    assert caption.paragraph.caption.position == 0
    assert caption.paragraph.caption.table.position == 0
    assert caption.text == "Table 1: Infectious diseases and where to find them."
    assert caption.paragraph.position == 18

    # Check figures
    assert len(doc.figures) == 0

    # Check that doc has a table
    assert len(doc.tables) == 3
    assert doc.tables[0].position == 0
    assert doc.tables[0].document.name == "diseases"

    # Check that doc has cells
    assert len(doc.cells) == 25

    sentence = sentences[10]
    logger.info(" {}".format(sentence))

    # Check the sentence's cell
    assert sentence.table.position == 0
    assert sentence.cell.row_start == 2
    assert sentence.cell.col_start == 1
    assert sentence.cell.position == 4

    # Test structural attributes
    assert sentence.xpath == "/html/body/table[1]/tbody/tr[3]/td[1]/p"
    assert sentence.html_tag == "p"
    assert sentence.html_attrs == ["class=s6", "style=padding-top: 1pt"]

    # Test visual attributes
    assert sentence.page == [1, 1, 1]
    assert sentence.top == [342, 296, 356]
    assert sentence.left == [318, 369, 318]

    # Test lingual attributes
    assert sentence.ner_tags == ["O", "O", "GPE"]
    assert sentence.dep_labels == ["ROOT", "prep", "pobj"]

    assert len(doc.sentences) == 37
from fonduer import Meta, init_logging # Configure logging for Fonduer init_logging(log_dir="logs") session = Meta.init(conn_string).Session() from fonduer.parser.preprocessors import HTMLDocPreprocessor from fonduer.parser import Parser docs_path = "data/train/" doc_preprocessor = HTMLDocPreprocessor(docs_path) corpus_parser = Parser(session, structural=True, lingual=True) corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL) from fonduer.parser.models import Document, Sentence print(f"Documents: {session.query(Document).count()}") print(f"Sentences: {session.query(Sentence).count()}") train_docs = session.query(Document).order_by(Document.name).all() # Mention from fonduerconfig import mention_classes, mention_spaces, matchers, candidate_classes from fonduer.candidates import MentionExtractor mention_extractor = MentionExtractor( session,
def test_parse_md_paragraphs(caplog):
    """Unit test of Paragraph parsing.

    Parses ``md_para.html`` (with its PDF for visual features). Checks the
    resulting figures, table, cells, sentences and paragraphs.
    """
    caplog.set_level(logging.INFO)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    PARALLEL = 1
    max_docs = 1
    docs_path = "tests/data/html_simple/md_para.html"
    pdf_path = "tests/data/pdf_simple/md_para.pdf"

    # Preprocess, then parse the md document
    preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    parser = Parser(
        structural=True, tabular=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    parser.apply(preprocessor, parallelism=PARALLEL)

    # Grab the document
    doc = session.query(Document).order_by(Document.name).all()[0]
    assert doc.name == "md_para"

    # Figures
    figures = doc.figures
    assert len(figures) == 6

    bear = figures[0]
    assert bear.url == "http://placebear.com/200/200"
    assert bear.position == 0
    assert bear.section.position == 0
    assert len(bear.captions) == 0
    assert bear.stable_id == "md_para::figure:0"
    assert bear.cell.position == 13

    kookaburra = figures[2]
    assert (
        kookaburra.url
        == "http://html5doctor.com/wp-content/uploads/2010/03/kookaburra.jpg"
    )
    assert kookaburra.position == 2
    assert len(kookaburra.captions) == 1
    caption_sentences = kookaburra.captions[0].paragraphs[0].sentences
    assert len(caption_sentences) == 3
    assert caption_sentences[0].text == "Australian Birds."

    pelican = figures[4]
    assert len(pelican.captions) == 0
    assert (
        pelican.url == "http://html5doctor.com/wp-content/uploads/2010/03/pelican.jpg"
    )

    # Table
    tables = doc.tables
    assert len(tables) == 1
    assert tables[0].position == 0
    assert tables[0].section.position == 0

    # Cells: (index into doc.cells, expected row_start, expected col_start)
    assert len(doc.cells) == 16
    cells = list(doc.cells)
    for idx, row, col in ((0, 0, 0), (10, 2, 2)):
        cell = cells[idx]
        assert cell.row_start == row
        assert cell.col_start == col
        assert cell.position == idx
        assert cell.table.position == 0

    # Sentences
    assert len(doc.sentences) == 51
    ordered = sorted(doc.sentences, key=lambda s: s.position)
    sent1, sent2, sent3 = ordered[1], ordered[2], ordered[3]
    assert sent1.text == "This is some basic, sample markdown."
    assert sent2.text == (
        "Unlike the other markdown document, however, "
        "this document actually contains paragraphs of text."
    )
    for sent in (sent1, sent2, sent3):
        assert sent.paragraph.position == 1
        assert sent.section.position == 0

    # Paragraphs
    paragraphs = doc.paragraphs
    assert len(paragraphs) == 46
    assert len(paragraphs[1].sentences) == 3
    assert len(paragraphs[2].sentences) == 1
def test_parse_md_details(caplog):
    """Unit test of the final results stored in the database of the md document.

    Only the stored results are inspected, so the implementation of the
    ParserUDF's apply() is free to change without affecting this test.
    """
    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    PARALLEL = 1
    max_docs = 1
    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/md.pdf"

    # Preprocess, then parse the md document
    preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    parser = Parser(
        structural=True, tabular=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    parser.apply(preprocessor, parallelism=PARALLEL)

    # Grab the md document
    doc = session.query(Document).order_by(Document.name).all()[0]
    assert doc.name == "md"

    # Figure
    assert len(doc.figures) == 1
    figure = doc.figures[0]
    assert figure.url == "http://placebear.com/200/200"
    assert figure.position == 0
    assert figure.section.position == 0
    assert figure.stable_id == "md::figure:0"

    # Table
    assert len(doc.tables) == 1
    table = doc.tables[0]
    assert table.position == 0
    assert table.section.position == 0
    assert table.document.name == "md"

    # Cells: (index into doc.cells, expected row_start, expected col_start)
    assert len(doc.cells) == 16
    cells = list(doc.cells)
    for idx, row, col in ((0, 0, 0), (10, 2, 2)):
        cell = cells[idx]
        assert cell.row_start == row
        assert cell.col_start == col
        assert cell.position == idx
        assert cell.document.name == "md"
        assert cell.table.position == 0

    # Sentences, in document order
    assert len(doc.sentences) == 45
    ordered = sorted(doc.sentences, key=lambda s: s.position)
    spicy = ordered[25]
    assert spicy.text == "Spicy"
    assert spicy.table.position == 0
    assert spicy.table.section.position == 0
    assert spicy.cell.row_start == 0
    assert spicy.cell.col_start == 2

    logger.info(f"Doc: {doc}")
    for idx, sentence in enumerate(doc.sentences):
        logger.info(f" Sentence[{idx}]: {sentence.text}")

    header = ordered[0]
    # Test structural attributes
    assert header.xpath == "/html/body/h1"
    assert header.html_tag == "h1"
    assert header.html_attrs == ["id=sample-markdown"]

    # Test visual attributes
    assert header.page == [1, 1]
    assert header.top == [35, 35]
    assert header.bottom == [61, 61]
    assert header.right == [111, 231]
    assert header.left == [35, 117]

    # Test lingual attributes
    assert header.ner_tags == ["O", "O"]
    assert header.dep_labels == ["compound", "ROOT"]
def test_cand_gen(caplog):
    """Test extracting candidates from mentions from documents.

    Also checks that MentionExtractor rejects mismatched arities, that
    deleting candidates keeps their mentions, and that deleting mentions
    removes the candidates built on them.
    """
    import sys

    caplog.set_level(logging.INFO)

    # SpaCy on mac has issue on parallel parsing, so only macOS runs serially.
    # ``os.name == "posix"`` would also match Linux.
    if sys.platform == "darwin":
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 10
    session = Meta.init("postgres://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing (skipped when the database already holds max_docs documents)
    num_docs = session.query(Document).count()
    if num_docs != max_docs:
        logger.info("Parsing...")
        doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
        corpus_parser = Parser(
            structural=True, lingual=True, visual=True, pdf_path=pdf_path
        )
        corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 5892
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    with pytest.raises(ValueError):
        MentionExtractor(
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    with pytest.raises(ValueError):
        MentionExtractor(
            [Part, Temp, Volt],
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor = MentionExtractor(
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 234
    assert session.query(Volt).count() == 108
    assert session.query(Temp).count() == 118
    part = session.query(Part).order_by(Part.id).all()[0]
    volt = session.query(Volt).order_by(Volt.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info("Part: {}".format(part.span))
    logger.info("Volt: {}".format(volt.span))
    logger.info("Temp: {}".format(temp.span))

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 3385
    assert session.query(PartVolt).count() == 3364
    assert session.query(Candidate).count() == 6749
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].volts) == 33
    assert len(docs[0].temps) == 18

    # Test that deletion of a Candidate does not delete the Mention
    session.query(PartTemp).delete()
    assert session.query(PartTemp).count() == 0
    assert session.query(Temp).count() == 118
    assert session.query(Part).count() == 234

    # Test deletion of Candidate if Mention is deleted
    assert session.query(PartVolt).count() == 3364
    assert session.query(Volt).count() == 108
    session.query(Volt).delete()
    assert session.query(Volt).count() == 0
    assert session.query(PartVolt).count() == 0
def test_e2e_logistic_regression(caplog):
    """Run an end-to-end test on documents of the hardware domain.

    The test proceeds in these stages:

    1. Parse the corpus if the database does not already hold ``max_docs``
       documents, then check per-document structure counts.
    2. Split the documents into train/dev/test and extract Part/Temp mentions
       and PartTemp candidates.
    3. Featurize the candidates.
    4. Train a generative model on labeling functions, then a
       logistic-regression model on its marginals.
    5. Check the entity-level F1 on the test split. This is done once with a
       small labeling-function set (expected 0.3 < F1 < 0.7), then again after
       adding more labeling functions (expected F1 > 0.7).
    """
    import sys

    caplog.set_level(logging.INFO)
    # SpaCy on mac has issue on parallel parsing, so only macOS runs serially.
    # ``os.name == "posix"`` would also match Linux.
    if sys.platform == "darwin":
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 12
    session = Meta.init("postgres://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    num_docs = session.query(Document).count()
    if num_docs != max_docs:
        logger.info("Parsing...")
        corpus_parser = Parser(
            structural=True, lingual=True, visual=True, pdf_path=pdf_path
        )
        corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs

    num_docs = session.query(Document).count()
    logger.info("Docs: {}".format(num_docs))
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info("Sentences: {}".format(num_sentences))

    # Divide into test and train
    docs = session.query(Document).order_by(Document.name).all()
    ld = len(docs)

    # Per-document structure counts, in document-name order
    assert [len(doc.sentences) for doc in docs] == [
        828, 706, 819, 684, 552, 758, 597, 165, 250, 533, 354, 547,
    ]
    # Check table numbers
    assert [len(doc.tables) for doc in docs] == [
        9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9,
    ]
    # Check figure numbers
    assert [len(doc.figures) for doc in docs] == [
        32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27,
    ]
    # Check caption numbers
    assert [len(doc.captions) for doc in docs] == [0] * max_docs

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (_, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 127

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor([PartTemp], throttlers=[temp_throttler])

    # Split index i: 0 = train, 1 = dev, 2 = test. Use a distinct loop name so
    # ``docs`` keeps referring to the full document list.
    for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3201
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 420

    train_cands = session.query(PartTemp).filter(PartTemp.split == 0).all()

    featurizer = FeatureAnnotator(PartTemp)
    F_train = featurizer.apply(split=0, replace_key_set=True, parallelism=PARALLEL)
    logger.info(F_train.shape)
    F_dev = featurizer.apply(split=1, replace_key_set=False, parallelism=PARALLEL)
    logger.info(F_dev.shape)
    F_test = featurizer.apply(split=2, replace_key_set=False, parallelism=PARALLEL)
    logger.info(F_test.shape)

    gold_file = "tests/data/hardware_tutorial_gold.csv"
    load_hardware_labels(session, PartTemp, gold_file, ATTRIBUTE, annotator_name="gold")

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = LabelAnnotator(PartTemp, lfs=stg_temp_lfs)
    L_train = labeler.apply(split=0, clear=True, parallelism=PARALLEL)
    logger.info(L_train.shape)

    load_gold_labels(session, annotator_name="gold", split=0)

    gen_model = GenerativeModel()
    gen_model.train(
        L_train, epochs=500, decay=0.9, step_size=0.001 / L_train.shape[0], reg_param=0
    )
    logger.info("LF Accuracy: {}".format(gen_model.weights.lf_accuracy))

    load_gold_labels(session, annotator_name="gold", split=1)

    train_marginals = gen_model.marginals(L_train)

    disc_model = LogisticRegression()
    disc_model.train((train_cands, F_train), train_marginals, n_epochs=200, lr=0.001)

    load_gold_labels(session, annotator_name="gold", split=2)

    test_candidates = [F_test.get_candidate(session, i) for i in range(F_test.shape[0])]

    def predicted_true(model):
        """Return the test candidates the model scores above zero.

        ``np.flatnonzero`` returns an empty index array when nothing scores
        above zero; ``np.nditer`` would raise on that zero-sized operand.
        Assumes ``predictions`` returns a 1-D score array.
        """
        test_score = model.predictions((test_candidates, F_test))
        return [test_candidates[i] for i in np.flatnonzero(test_score > 0)]

    # Test fixture shipped with the repository (trusted data for pickle.load).
    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    def scored_f1(true_pred):
        """Log entity-level precision/recall/F1 on the test split; return F1."""
        (TP, FP, FN) = entity_level_f1(
            true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc
        )
        tp_len = len(TP)
        fp_len = len(FP)
        fn_len = len(FN)
        prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
        rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
        f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")
        logger.info("prec: {}".format(prec))
        logger.info("rec: {}".format(rec))
        logger.info("f1: {}".format(f1))
        return f1

    f1 = scored_f1(predicted_true(disc_model))
    assert 0.3 < f1 < 0.7

    stg_temp_lfs_2 = [
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler = LabelAnnotator(PartTemp, lfs=stg_temp_lfs_2)
    L_train = labeler.apply(
        split=0,
        clear=False,
        update_keys=True,
        update_values=True,
        parallelism=PARALLEL,
    )

    gen_model = GenerativeModel()
    gen_model.train(
        L_train, epochs=500, decay=0.9, step_size=0.001 / L_train.shape[0], reg_param=0
    )

    train_marginals = gen_model.marginals(L_train)

    disc_model = LogisticRegression()
    disc_model.train((train_cands, F_train), train_marginals, n_epochs=200, lr=0.001)

    f1 = scored_f1(predicted_true(disc_model))
    assert f1 > 0.7
def test_e2e(caplog):
    """Run an end-to-end test on documents of the hardware domain.

    Parses 12 hardware datasheets, extracts Part/Temp mentions and PartTemp
    candidates, featurizes and labels them, then trains several discriminative
    models and checks their entity-level F1 on the test split.

    NOTE(review): this module defines ``test_e2e`` more than once; pytest only
    collects the last definition, so confirm which variant is meant to run.
    """

    def _entity_f1(test_score):
        """Score the candidates with a positive score and return entity-level F1.

        Selects the ``test_cands[0]`` entries whose score is > 0, compares them
        with the gold file via ``entity_level_f1``, and logs precision, recall
        and F1. Reads ``test_cands``, ``gold_file``, ``test_docs`` and
        ``parts_by_doc`` from the enclosing test, so it may only be called once
        those names are bound.
        """
        # np.flatnonzero returns an empty index array when nothing is positive;
        # np.nditer would raise ValueError on that zero-sized operand instead.
        true_pred = [test_cands[0][i] for i in np.flatnonzero(test_score > 0)]
        (TP, FP, FN) = entity_level_f1(
            true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc
        )
        tp_len = len(TP)
        fp_len = len(FP)
        fn_len = len(FN)
        prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
        rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
        f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")
        logger.info("prec: {}".format(prec))
        logger.info("rec: {}".format(rec))
        logger.info("f1: {}".format(f1))
        return f1

    caplog.set_level(logging.INFO)
    # SpaCy on mac has issue on parallel parsing.
    # NOTE(review): os.name is "posix" on Linux as well as macOS, so this branch
    # forces single-core parsing on Linux too.
    if os.name == "posix":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 12
    session = Meta.init("postgres://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info("Docs: {}".format(num_docs))
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info("Sentences: {}".format(num_sentences))

    # Divide into test and train
    docs = session.query(Document).order_by(Document.name).all()
    ld = len(docs)

    # Per-document structure counts, in document-name order. Comparing whole
    # lists lets pytest report every mismatching document at once.
    assert [len(doc.sentences) for doc in docs] == [
        799, 663, 784, 661, 513, 700, 528, 161, 228, 511, 331, 528,
    ]
    # Check table numbers
    assert [len(doc.tables) for doc in docs] == [
        9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9,
    ]
    # Check figure numbers
    assert [len(doc.figures) for doc in docs] == [
        32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27,
    ]
    # Check caption numbers
    assert [len(doc.captions) for doc in docs] == [0] * ld

    # First half of the name-sorted documents is train, next quarter dev,
    # remainder test.
    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    for i, doc in enumerate(sorted(docs, key=lambda d: d.name)):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 134
    assert len(mention_extractor.get_mentions()) == 2
    assert len(mention_extractor.get_mentions()[0]) == 299

    doc_112823 = session.query(Document).filter(Document.name == "112823").first()
    assert len(mention_extractor.get_mentions(docs=[doc_112823])[0]) == 70

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )

    # Split index i: 0 = train, 1 = dev, 2 = test.
    for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3346
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 420

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 3346
    assert len(candidate_extractor.get_candidates(docs=[doc_112823])[0]) == 1178

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 61
    assert session.query(FeatureKey).count() == 676

    # Test Dropping FeatureKey
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NFP NN NFP]"])
    assert session.query(FeatureKey).count() == 675

    session.query(Feature).delete()

    featurizer.apply(split=0, train=True, parallelism=1)
    assert session.query(Feature).count() == 3346
    assert session.query(FeatureKey).count() == 3578
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3346, 3578)
    assert len(featurizer.get_keys()) == 3578

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 3407
    assert session.query(FeatureKey).count() == 3578
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (61, 3578)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 3827
    assert session.query(FeatureKey).count() == 3578
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (420, 3578)

    gold_file = "tests/data/hardware_tutorial_gold.csv"
    load_hardware_labels(session, PartTemp, gold_file, ATTRIBUTE, annotator_name="gold")
    assert session.query(GoldLabel).count() == 3827

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    # LFs must be given as one list per candidate class, not a flat list.
    with pytest.raises(ValueError):
        labeler.apply(split=0, lfs=stg_temp_lfs, train=True, parallelism=PARALLEL)

    labeler.apply(split=0, lfs=[stg_temp_lfs], train=True, parallelism=PARALLEL)
    assert session.query(Label).count() == 3346
    assert session.query(LabelKey).count() == 6
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3346, 6)
    assert len(labeler.get_keys()) == 6

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3346, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3346, 1)

    gen_model = LabelLearner(cardinalities=2)
    gen_model.train(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])[:, 1]

    disc_model = LogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=20, lr=0.001
    )

    test_score = disc_model.predictions((test_cands[0], F_test[0]), b=0.6)

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    f1 = _entity_f1(test_score)
    assert f1 < 0.7 and f1 > 0.3

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0, lfs=[stg_temp_lfs_2], parallelism=PARALLEL)
    assert session.query(Label).count() == 3346
    assert session.query(LabelKey).count() == 13
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3346, 13)

    gen_model = LabelLearner(cardinalities=2)
    gen_model.train(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])[:, 1]

    disc_model = LogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=20, lr=0.001
    )

    test_score = disc_model.predictions((test_cands[0], F_test[0]), b=0.6)
    f1 = _entity_f1(test_score)
    assert f1 > 0.7

    # Testing LSTM
    disc_model = LSTM()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=5, lr=0.001
    )

    test_score = disc_model.predictions((test_cands[0], F_test[0]), b=0.6)
    f1 = _entity_f1(test_score)
    assert f1 > 0.7

    # Testing Sparse Logistic Regression
    disc_model = SparseLogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=20, lr=0.001
    )

    test_score = disc_model.predictions((test_cands[0], F_test[0]), b=0.9)
    f1 = _entity_f1(test_score)
    assert f1 > 0.7
def test_e2e(caplog):
    """Run an end-to-end test on documents of the hardware domain.

    Parses 12 hardware datasheets, extracts Part/Temp/Volt mentions and
    PartTemp/PartVolt candidates, exercises feature-key and label-key
    management, then trains several discriminative models on PartTemp and
    checks their entity-level and mention-level F1 on the test split.

    NOTE(review): this module defines ``test_e2e`` more than once; pytest only
    collects the last definition, so confirm which variant is meant to run.
    """

    def _entity_f1(test_score):
        """Score the candidates predicted TRUE and return entity-level F1.

        Selects the ``test_cands[0]`` entries whose score equals ``TRUE``,
        compares them with the gold file via ``entity_level_f1``, and logs
        precision, recall and F1. Reads ``test_cands``, ``gold_file``,
        ``test_docs`` and ``parts_by_doc`` from the enclosing test, so it may
        only be called once those names are bound.
        """
        # np.flatnonzero returns an empty index array when nothing is TRUE;
        # np.nditer would raise ValueError on that zero-sized operand instead.
        true_pred = [test_cands[0][i] for i in np.flatnonzero(test_score == TRUE)]
        (TP, FP, FN) = entity_level_f1(
            true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc
        )
        tp_len = len(TP)
        fp_len = len(FP)
        fn_len = len(FN)
        prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
        rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
        f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")
        logger.info(f"prec: {prec}")
        logger.info(f"rec: {rec}")
        logger.info(f"f1: {f1}")
        return f1

    def _candidate_classes_of(key_name):
        """Return the ``candidate_classes`` of the single FeatureKey named ``key_name``."""
        return (
            session.query(FeatureKey)
            .filter(FeatureKey.name == key_name)
            .one()
            .candidate_classes
        )

    caplog.set_level(logging.INFO)

    PARALLEL = 4

    max_docs = 12
    fonduer.init_logging(
        log_dir="log_folder",
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = fonduer.Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Per-document structure counts, in sorted-document order. Comparing whole
    # lists lets pytest report every mismatching document at once.
    assert [len(doc.sentences) for doc in docs] == [
        799, 663, 784, 661, 513, 700, 528, 161, 228, 511, 331, 528,
    ]
    # Check table numbers
    assert [len(doc.tables) for doc in docs] == [
        9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9,
    ]
    # Check figure numbers
    assert [len(doc.figures) for doc in docs] == [
        32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27,
    ]
    # Check caption numbers
    assert [len(doc.captions) for doc in docs] == [0] * ld

    # First half of the name-sorted documents is train, next quarter dev,
    # remainder test.
    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    for i, doc in enumerate(sorted(docs, key=lambda d: d.name)):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 138
    assert session.query(Volt).count() == 140
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299

    doc_112823 = session.query(Document).filter(Document.name == "112823").first()
    assert len(mention_extractor.get_mentions(docs=[doc_112823])[0]) == 70

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    # Split index i: 0 = train, 1 = dev, 2 = test.
    for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3493
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 416
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3493
    assert len(candidate_extractor.get_candidates(docs=[doc_112823])[0]) == 1432

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    assert session.query(FeatureKey).count() == 1128

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NNP NN IN]"])
    assert session.query(FeatureKey).count() == 1127

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(_candidate_classes_of("DDL_e1_LEMMA_SEQ_[bc182]")) == {
        "part_temp",
        "part_volt",
    }
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt])
    assert _candidate_classes_of("DDL_e1_LEMMA_SEQ_[bc182]") == ["part_temp"]
    assert session.query(FeatureKey).count() == 1127

    # Inserting the removed key
    featurizer.upsert_keys(
        ["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp, PartVolt]
    )
    assert set(_candidate_classes_of("DDL_e1_LEMMA_SEQ_[bc182]")) == {
        "part_temp",
        "part_volt",
    }
    assert session.query(FeatureKey).count() == 1127
    # Removing the key again
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1126
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6478
    assert session.query(FeatureKey).count() == 4122
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3493, 4122)
    assert F_train[1].shape == (2985, 4122)
    assert len(featurizer.get_keys()) == 4122

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6692
    assert session.query(FeatureKey).count() == 4122
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (61, 4122)
    assert F_dev[1].shape == (153, 4122)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8252
    assert session.query(FeatureKey).count() == 4122
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (416, 4122)
    assert F_test[1].shape == (1144, 4122)

    gold_file = "tests/data/hardware_tutorial_gold.csv"
    load_hardware_labels(session, PartTemp, gold_file, ATTRIBUTE, annotator_name="gold")
    assert session.query(GoldLabel).count() == 3970
    load_hardware_labels(session, PartVolt, gold_file, ATTRIBUTE, annotator_name="gold")
    assert session.query(GoldLabel).count() == 8252

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]
    labeler = Labeler(session, [PartTemp, PartVolt])

    # LFs must be given as one list per candidate class, not a flat list.
    with pytest.raises(ValueError):
        labeler.apply(split=0, lfs=stg_temp_lfs, train=True, parallelism=PARALLEL)

    labeler.apply(
        split=0, lfs=[stg_temp_lfs, ce_v_max_lfs], train=True, parallelism=PARALLEL
    )
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3493, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3493, 1)

    gen_model = LabelModel(k=2)
    gen_model.train_model(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])

    # The training split doubles as the dev set here.
    disc_model = LogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]),
        train_marginals,
        X_dev=(train_cands[0], F_train[0]),
        Y_dev=np.array(L_train_gold[0].todense()).reshape(-1),
        b=0.6,
        pos_label=TRUE,
        n_epochs=5,
        lr=0.001,
    )

    test_score = disc_model.predict((test_cands[0], F_test[0]), b=0.6, pos_label=TRUE)

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    f1 = _entity_f1(test_score)
    assert f1 < 0.7 and f1 > 0.3

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0, lfs=[stg_temp_lfs_2, ce_v_max_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 16)

    gen_model = LabelModel(k=2)
    gen_model.train_model(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])

    disc_model = LogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=5, lr=0.001
    )

    test_score = disc_model.predict((test_cands[0], F_test[0]), b=0.6, pos_label=TRUE)
    f1 = _entity_f1(test_score)
    assert f1 > 0.7

    # Testing LSTM
    disc_model = LSTM()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=5, lr=0.001
    )

    test_score = disc_model.predict((test_cands[0], F_test[0]), b=0.6, pos_label=TRUE)
    f1 = _entity_f1(test_score)
    assert f1 > 0.7

    # Testing Sparse Logistic Regression
    disc_model = SparseLogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=5, lr=0.001
    )

    test_score = disc_model.predict((test_cands[0], F_test[0]), b=0.6, pos_label=TRUE)
    f1 = _entity_f1(test_score)
    assert f1 > 0.7

    # Testing Sparse LSTM
    disc_model = SparseLSTM()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=5, lr=0.001
    )

    test_score = disc_model.predict((test_cands[0], F_test[0]), b=0.6, pos_label=TRUE)
    f1 = _entity_f1(test_score)
    assert f1 > 0.7

    # Evaluate mention level scores
    L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")
    Y_test = np.array(L_test_gold[0].todense()).reshape(-1)

    scores = disc_model.score((test_cands[0], F_test[0]), Y_test, b=0.6, pos_label=TRUE)
    logger.info(scores)
    assert scores["f1"] > 0.6
def test_parse_style(caplog):
    """Test style tag parsing.

    Parses ``ext_diseases.html`` and checks that selected sentences carry the
    expected ``html_attrs`` (class, custom attributes and inline style).
    """
    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    # SpaCy on mac has issue on parallel parsing.
    # NOTE(review): os.name is "posix" on Linux as well as macOS, so this branch
    # forces single-core parsing on Linux too.
    if os.name == "posix":
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 1
    docs_path = "tests/data/html_extended/ext_diseases.html"
    pdf_path = "tests/data/pdf_extended/ext_diseases.pdf"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    # Create a Parser and parse the HTML document
    parser = Parser(structural=True, lingual=True, visual=True, pdf_path=pdf_path)
    parser.apply(preprocessor, parallelism=PARALLEL)

    # Grab the document
    doc = session.query(Document).order_by(Document.name).all()[0]

    # Grab the sentences parsed by the Parser
    sentences = session.query(Sentence).order_by(Sentence.position).all()

    logger.warning("Doc: {}".format(doc))
    for i, sentence in enumerate(sentences):
        logger.warning(" Sentence[{}]: {}".format(i, sentence.html_attrs))

    # sentences for testing
    sub_sentences = [
        {
            "index": 6,
            "attr": [
                "class=col-header",
                "hobbies=work:hard;play:harder",
                "type=phenotype",
                "style=background: #f1f1f1; color: aquamarine; font-size: 18px;",
            ],
        },
        {"index": 9, "attr": ["class=row-header", "style=background: #f1f1f1;"]},
        {"index": 11, "attr": ["class=cell", "style=text-align: center;"]},
    ]

    # Assertions: one assert per sentence so pytest reports which sentence
    # mismatched (assert all(<generator>) only reports "False").
    for expected in sub_sentences:
        assert sentences[expected["index"]].html_attrs == expected["attr"]
def test_e2e(database_session): """Run an end-to-end test on documents of the hardware domain.""" # GitHub Actions gives 2 cores # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners PARALLEL = 2 max_docs = 12 fonduer.init_logging( format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s", level=logging.INFO, ) session = database_session docs_path = "tests/data/html/" pdf_path = "tests/data/pdf/" doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs) corpus_parser = Parser( session, parallelism=PARALLEL, structural=True, lingual=True, visual_parser=PdfVisualParser(pdf_path), ) corpus_parser.apply(doc_preprocessor) assert session.query(Document).count() == max_docs num_docs = session.query(Document).count() logger.info(f"Docs: {num_docs}") assert num_docs == max_docs num_sentences = session.query(Sentence).count() logger.info(f"Sentences: {num_sentences}") # Divide into test and train docs = sorted(corpus_parser.get_documents()) last_docs = sorted(corpus_parser.get_last_documents()) ld = len(docs) assert ld == len(last_docs) assert len(docs[0].sentences) == len(last_docs[0].sentences) train_docs = set() dev_docs = set() test_docs = set() splits = (0.5, 0.75) data = [(doc.name, doc) for doc in docs] data.sort(key=lambda x: x[0]) for i, (doc_name, doc) in enumerate(data): if i < splits[0] * ld: train_docs.add(doc) elif i < splits[1] * ld: dev_docs.add(doc) else: test_docs.add(doc) logger.info([x.name for x in train_docs]) # NOTE: With multi-relation support, return values of getting candidates, # mentions, or sparse matrices are formatted as a list of lists. This means # that with a single relation, we need to index into the list of lists to # get the candidates/mentions/sparse matrix for a particular relation or # mention. 
# Mention Extraction part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3) temp_ngrams = MentionNgramsTemp(n_max=2) volt_ngrams = MentionNgramsVolt(n_max=1) mention_extractor = MentionExtractor( session, [Part, Temp, Volt], [part_ngrams, temp_ngrams, volt_ngrams], [part_matcher, temp_matcher, volt_matcher], ) mention_extractor.apply(docs, parallelism=PARALLEL) assert len(mention_extractor.get_mentions()) == 3 assert len(mention_extractor.get_mentions(docs)) == 3 # Candidate Extraction candidate_extractor = CandidateExtractor( session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]) for i, docs in enumerate([train_docs, dev_docs, test_docs]): candidate_extractor.apply(docs, split=i, parallelism=PARALLEL) # Grab candidate lists train_cands = candidate_extractor.get_candidates(split=0, sort=True) dev_cands = candidate_extractor.get_candidates(split=1, sort=True) test_cands = candidate_extractor.get_candidates(split=2, sort=True) # Candidate lists should be deterministically sorted. 
assert ("112823::implicit_span_mention:11059:11065:part_expander:0" == train_cands[0][0][0].context.get_stable_id()) assert ("112823::implicit_span_mention:2752:2754:temp_expander:0" == train_cands[0][0][1].context.get_stable_id()) assert len(train_cands) == 2 assert len(candidate_extractor.get_candidates(docs)) == 2 # Featurization featurizer = Featurizer(session, [PartTemp, PartVolt]) # Test that FeatureKey is properly reset featurizer.apply(split=1, train=True, parallelism=PARALLEL) assert session.query(Feature).count() == 214 num_feature_keys = session.query(FeatureKey).count() assert num_feature_keys == 1278 # Test Dropping FeatureKey # Should force a row deletion featurizer.drop_keys(["BASIC_e0_CONTAINS_WORDS_[BC182]"]) assert session.query(FeatureKey).count() == num_feature_keys - 1 # Should only remove the part_volt as a relation and leave part_temp assert set( session.query(FeatureKey).filter( FeatureKey.name == "DDL_e0_LEMMA_SEQ_[bc182]").one().candidate_classes) == { "part_temp", "part_volt" } featurizer.drop_keys(["DDL_e0_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt]) assert session.query(FeatureKey).filter( FeatureKey.name == "DDL_e0_LEMMA_SEQ_[bc182]").one().candidate_classes == ["part_temp"] assert session.query(FeatureKey).count() == num_feature_keys - 1 # Inserting the removed key featurizer.upsert_keys(["DDL_e0_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp, PartVolt]) assert set( session.query(FeatureKey).filter( FeatureKey.name == "DDL_e0_LEMMA_SEQ_[bc182]").one().candidate_classes) == { "part_temp", "part_volt" } assert session.query(FeatureKey).count() == num_feature_keys - 1 # Removing the key again featurizer.drop_keys(["DDL_e0_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt]) # Removing the last relation from a key should delete the row featurizer.drop_keys(["DDL_e0_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp]) assert session.query(FeatureKey).count() == num_feature_keys - 2 
session.query(Feature).delete(synchronize_session="fetch") session.query(FeatureKey).delete(synchronize_session="fetch") featurizer.apply(split=0, train=True, parallelism=PARALLEL) # the number of Features should equals to the total number of train candidates num_features = session.query(Feature).count() assert num_features == len(train_cands[0]) + len(train_cands[1]) num_feature_keys = session.query(FeatureKey).count() assert num_feature_keys == 4629 F_train = featurizer.get_feature_matrices(train_cands) assert F_train[0].shape == (len(train_cands[0]), num_feature_keys) assert F_train[1].shape == (len(train_cands[1]), num_feature_keys) assert len(featurizer.get_keys()) == num_feature_keys featurizer.apply(split=1, parallelism=PARALLEL) # the number of Features should increate by the total number of dev candidates num_features += len(dev_cands[0]) + len(dev_cands[1]) assert session.query(Feature).count() == num_features assert session.query(FeatureKey).count() == num_feature_keys F_dev = featurizer.get_feature_matrices(dev_cands) assert F_dev[0].shape == (len(dev_cands[0]), num_feature_keys) assert F_dev[1].shape == (len(dev_cands[1]), num_feature_keys) featurizer.apply(split=2, parallelism=PARALLEL) # the number of Features should increate by the total number of test candidates num_features += len(test_cands[0]) + len(test_cands[1]) assert session.query(Feature).count() == num_features assert session.query(FeatureKey).count() == num_feature_keys F_test = featurizer.get_feature_matrices(test_cands) assert F_test[0].shape == (len(test_cands[0]), num_feature_keys) assert F_test[1].shape == (len(test_cands[1]), num_feature_keys) gold_file = "tests/data/hardware_tutorial_gold.csv" labeler = Labeler(session, [PartTemp, PartVolt]) # This should raise an error, since gold labels are not yet loaded. 
with pytest.raises(ValueError): _ = labeler.get_gold_labels(train_cands, annotator="gold") labeler.apply( docs=last_docs, lfs=[[gold], [gold]], table=GoldLabel, train=True, parallelism=PARALLEL, ) # All candidates should now be gold-labeled. assert session.query(GoldLabel).count() == session.query(Candidate).count() stg_temp_lfs = [ LF_storage_row, LF_operating_row, LF_temperature_row, LF_tstg_row, LF_to_left, LF_negative_number_left, ] ce_v_max_lfs = [ LF_bad_keywords_in_row, LF_current_in_row, LF_non_ce_voltages_in_row, ] with pytest.raises(ValueError): labeler.apply(split=0, lfs=stg_temp_lfs, train=True, parallelism=PARALLEL) labeler.apply( docs=train_docs, lfs=[stg_temp_lfs, ce_v_max_lfs], train=True, parallelism=PARALLEL, ) assert session.query(Label).count() == len(train_cands[0]) + len( train_cands[1]) num_label_keys = session.query(LabelKey).count() assert num_label_keys == 9 L_train = labeler.get_label_matrices(train_cands) assert L_train[0].shape == (len(train_cands[0]), num_label_keys) assert L_train[1].shape == (len(train_cands[1]), num_label_keys) assert len(labeler.get_keys()) == num_label_keys # Test Dropping LabelerKey labeler.drop_keys(["LF_storage_row"]) assert len(labeler.get_keys()) == num_label_keys - 1 # Test Upserting LabelerKey labeler.upsert_keys(["LF_storage_row"]) assert "LF_storage_row" in [label.name for label in labeler.get_keys()] L_train_gold = labeler.get_gold_labels(train_cands) assert L_train_gold[0].shape == (len(train_cands[0]), 1) L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold") assert L_train_gold[0].shape == (len(train_cands[0]), 1) label_model = LabelModel(cardinality=2) label_model.fit(L_train=L_train[0], n_epochs=500, seed=1234, log_freq=100) train_marginals = label_model.predict_proba(L_train[0]) # Collect word counter word_counter = collect_word_counter(train_cands) emmental.init(fonduer.Meta.log_path) # Training config config = { "meta_config": { "verbose": False }, "model_config": { "model_path": 
None, "device": 0, "dataparallel": False }, "learner_config": { "n_epochs": 5, "optimizer_config": { "lr": 0.001, "l2": 0.0 }, "task_scheduler": "round_robin", }, "logging_config": { "evaluation_freq": 1, "counter_unit": "epoch", "checkpointing": False, "checkpointer_config": { "checkpoint_metric": { f"{ATTRIBUTE}/{ATTRIBUTE}/train/loss": "min" }, "checkpoint_freq": 1, "checkpoint_runway": 2, "clear_intermediate_checkpoints": True, "clear_all_checkpoints": True, }, }, } emmental.Meta.update_config(config=config) # Generate word embedding module arity = 2 # Geneate special tokens specials = [] for i in range(arity): specials += [f"~~[[{i}", f"{i}]]~~"] emb_layer = EmbeddingModule(word_counter=word_counter, word_dim=300, specials=specials) diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1) train_idxs = np.where(diffs > 1e-6)[0] train_dataloader = EmmentalDataLoader( task_to_label_dict={ATTRIBUTE: "labels"}, dataset=FonduerDataset( ATTRIBUTE, train_cands[0], F_train[0], emb_layer.word2id, train_marginals, train_idxs, ), split="train", batch_size=100, shuffle=True, ) tasks = create_task(ATTRIBUTE, 2, F_train[0].shape[1], 2, emb_layer, model="LogisticRegression") model = EmmentalModel(name=f"{ATTRIBUTE}_task") for task in tasks: model.add_task(task) emmental_learner = EmmentalLearner() emmental_learner.learn(model, [train_dataloader]) test_dataloader = EmmentalDataLoader( task_to_label_dict={ATTRIBUTE: "labels"}, dataset=FonduerDataset(ATTRIBUTE, test_cands[0], F_test[0], emb_layer.word2id, 2), split="test", batch_size=100, shuffle=False, ) test_preds = model.predict(test_dataloader, return_preds=True) positive = np.where( np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.6) true_pred = [test_cands[0][_] for _ in positive[0]] pickle_file = "tests/data/parts_by_doc_dict.pkl" with open(pickle_file, "rb") as f: parts_by_doc = pickle.load(f) (TP, FP, FN) = entity_level_f1(true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc) tp_len = 
len(TP) fp_len = len(FP) fn_len = len(FN) prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan") rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan") f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan") logger.info(f"prec: {prec}") logger.info(f"rec: {rec}") logger.info(f"f1: {f1}") assert f1 < 0.7 and f1 > 0.3 stg_temp_lfs_2 = [ LF_to_left, LF_test_condition_aligned, LF_collector_aligned, LF_current_aligned, LF_voltage_row_temp, LF_voltage_row_part, LF_typ_row, LF_complement_left_row, LF_too_many_numbers_row, LF_temp_on_high_page_num, LF_temp_outside_table, LF_not_temp_relevant, ] labeler.update(split=0, lfs=[stg_temp_lfs_2, ce_v_max_lfs], parallelism=PARALLEL) assert session.query(Label).count() == len(train_cands[0]) + len( train_cands[1]) num_label_keys = session.query(LabelKey).count() assert num_label_keys == 16 L_train = labeler.get_label_matrices(train_cands) assert L_train[0].shape == (len(train_cands[0]), num_label_keys) label_model = LabelModel(cardinality=2) label_model.fit(L_train=L_train[0], n_epochs=500, seed=1234, log_freq=100) train_marginals = label_model.predict_proba(L_train[0]) diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1) train_idxs = np.where(diffs > 1e-6)[0] train_dataloader = EmmentalDataLoader( task_to_label_dict={ATTRIBUTE: "labels"}, dataset=FonduerDataset( ATTRIBUTE, train_cands[0], F_train[0], emb_layer.word2id, train_marginals, train_idxs, ), split="train", batch_size=100, shuffle=True, ) valid_dataloader = EmmentalDataLoader( task_to_label_dict={ATTRIBUTE: "labels"}, dataset=FonduerDataset( ATTRIBUTE, train_cands[0], F_train[0], emb_layer.word2id, np.argmax(train_marginals, axis=1), train_idxs, ), split="valid", batch_size=100, shuffle=False, ) # Testing STL LogisticRegression emmental.Meta.reset() emmental.init(fonduer.Meta.log_path) emmental.Meta.update_config(config=config) tasks = create_task( ATTRIBUTE, 2, F_train[0].shape[1], 2, emb_layer, 
model="LogisticRegression", mode="STL", ) model = EmmentalModel(name=f"{ATTRIBUTE}_task") for task in tasks: model.add_task(task) emmental_learner = EmmentalLearner() emmental_learner.learn(model, [train_dataloader, valid_dataloader]) test_preds = model.predict(test_dataloader, return_preds=True) positive = np.where( np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7) true_pred = [test_cands[0][_] for _ in positive[0]] (TP, FP, FN) = entity_level_f1(true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc) tp_len = len(TP) fp_len = len(FP) fn_len = len(FN) prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan") rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan") f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan") logger.info(f"prec: {prec}") logger.info(f"rec: {rec}") logger.info(f"f1: {f1}") assert f1 > 0.7 # Testing STL LSTM emmental.Meta.reset() emmental.init(fonduer.Meta.log_path) emmental.Meta.update_config(config=config) tasks = create_task(ATTRIBUTE, 2, F_train[0].shape[1], 2, emb_layer, model="LSTM", mode="STL") model = EmmentalModel(name=f"{ATTRIBUTE}_task") for task in tasks: model.add_task(task) emmental_learner = EmmentalLearner() emmental_learner.learn(model, [train_dataloader]) test_preds = model.predict(test_dataloader, return_preds=True) positive = np.where( np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7) true_pred = [test_cands[0][_] for _ in positive[0]] (TP, FP, FN) = entity_level_f1(true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc) tp_len = len(TP) fp_len = len(FP) fn_len = len(FN) prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan") rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan") f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan") logger.info(f"prec: {prec}") logger.info(f"rec: {rec}") logger.info(f"f1: {f1}") assert f1 > 0.7 # Testing MTL LogisticRegression 
emmental.Meta.reset() emmental.init(fonduer.Meta.log_path) emmental.Meta.update_config(config=config) tasks = create_task( ATTRIBUTE, 2, F_train[0].shape[1], 2, emb_layer, model="LogisticRegression", mode="MTL", ) model = EmmentalModel(name=f"{ATTRIBUTE}_task") for task in tasks: model.add_task(task) emmental_learner = EmmentalLearner() emmental_learner.learn(model, [train_dataloader, valid_dataloader]) test_preds = model.predict(test_dataloader, return_preds=True) positive = np.where( np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7) true_pred = [test_cands[0][_] for _ in positive[0]] (TP, FP, FN) = entity_level_f1(true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc) tp_len = len(TP) fp_len = len(FP) fn_len = len(FN) prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan") rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan") f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan") logger.info(f"prec: {prec}") logger.info(f"rec: {rec}") logger.info(f"f1: {f1}") assert f1 > 0.7 # Testing LSTM emmental.Meta.reset() emmental.init(fonduer.Meta.log_path) emmental.Meta.update_config(config=config) tasks = create_task(ATTRIBUTE, 2, F_train[0].shape[1], 2, emb_layer, model="LSTM", mode="MTL") model = EmmentalModel(name=f"{ATTRIBUTE}_task") for task in tasks: model.add_task(task) emmental_learner = EmmentalLearner() emmental_learner.learn(model, [train_dataloader]) test_preds = model.predict(test_dataloader, return_preds=True) positive = np.where( np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7) true_pred = [test_cands[0][_] for _ in positive[0]] (TP, FP, FN) = entity_level_f1(true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc) tp_len = len(TP) fp_len = len(FP) fn_len = len(FN) prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan") rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan") f1 = 2 * (prec * rec) / (prec + rec) if prec 
+ rec > 0 else float("nan") logger.info(f"prec: {prec}") logger.info(f"rec: {rec}") logger.info(f"f1: {f1}") assert f1 > 0.7
# Configure logging for Fonduer
init_logging(log_dir=f"logs_{ATTRIBUTE}", level=logging.INFO)  # DEBUG LOGGING

session = Meta.init(conn_string).Session()

# Initialize NLP library for vector similarities.
# Run spaCy's downloader with the interpreter executing this script
# (sys.executable) rather than whatever "python3" resolves to on PATH, and pass
# an argument list instead of a shell string. As with the former os.system()
# call, the exit status is not checked, so a failed download does not abort.
import subprocess  # noqa: E402
import sys  # noqa: E402

subprocess.run(
    [sys.executable, "-m", "spacy", "download", "en_core_web_lg"], check=False
)

# 3.) Process documents into train,dev,test
print("\n#3 Process Document into train, dev, test sets")

# parse documents
has_documents = session.query(Document).count() > 0
corpus_parser = Parser(
    session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
)
# Only parse when the database is still empty, so re-runs reuse parsed docs.
if not has_documents:
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

print(f"Documents: {session.query(Document).count()}")
print(f"Sentences: {session.query(Sentence).count()}")

# split documents
docs = session.query(Document).order_by(Document.name).all()
ld = len(docs)
train_docs = set()
dev_docs = set()
test_docs = set()
def test_incremental(database_session):
    """Run an end-to-end test on incremental additions."""
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 1

    session = database_session

    docs_path = "tests/data/html/dtc114w.html"
    pdf_path = "tests/data/pdf/dtc114w.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    docs = corpus_parser.get_documents()
    # Documents produced by the most recent apply(); after the first batch
    # they must match the full corpus. (Calling get_documents() twice would
    # make the assertion below compare a list against itself.)
    last_docs = corpus_parser.get_last_documents()

    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Test if clear=True works
    mention_extractor.apply(docs, parallelism=PARALLEL, clear=True)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Candidate Extraction
    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 70
    assert session.query(Candidate).count() == 70

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 70

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == len(train_cands[0])
    num_feature_keys = session.query(FeatureKey).count()
    assert num_feature_keys == 514
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (len(train_cands[0]), num_feature_keys)
    assert len(featurizer.get_keys()) == num_feature_keys

    # Test Dropping FeatureKey
    featurizer.drop_keys(["BASIC_e1_LENGTH_1"])
    assert session.query(FeatureKey).count() == num_feature_keys - 1

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    labeler.apply(split=0, lfs=[stg_temp_lfs], train=True, parallelism=PARALLEL)
    assert session.query(Label).count() == len(train_cands[0])

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (len(train_cands[0]), 5)
    assert len(labeler.get_keys()) == 5

    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser.apply(doc_preprocessor, pdf_path=pdf_path, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    new_docs = corpus_parser.get_last_documents()

    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # Get mentions from just the new docs
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)

    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Test if existing mentions are skipped.
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)

    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Just run candidate extraction and assign to split 0
    candidate_extractor.apply(new_docs, split=0, parallelism=PARALLEL, clear=False)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1501

    # Test if existing candidates are skipped.
    candidate_extractor.apply(new_docs, split=0, parallelism=PARALLEL, clear=False)
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1501

    # Update features
    featurizer.update(new_docs, parallelism=PARALLEL)
    assert session.query(Feature).count() == len(train_cands[0])
    num_feature_keys = session.query(FeatureKey).count()
    assert num_feature_keys == 2608
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (len(train_cands[0]), num_feature_keys)
    assert len(featurizer.get_keys()) == num_feature_keys

    # Update LF_storage_row. Now it always returns ABSTAIN.
    @labeling_function(name="LF_storage_row")
    def LF_storage_row_updated(c):
        return ABSTAIN

    stg_temp_lfs = [
        LF_storage_row_updated,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    # Update Labels
    labeler.update(docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == len(train_cands[0])
    # Only 5 because LF_storage_row doesn't apply to any doc (always ABSTAIN)
    num_label_keys = session.query(LabelKey).count()
    assert num_label_keys == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (len(train_cands[0]), num_label_keys)

    # Test clear
    featurizer.clear(train=True)
    assert session.query(FeatureKey).count() == 0
def _get_parser(self, session: Session) -> Parser:
    """Build a ``Parser`` bound to ``session`` with lingual and structural parsing enabled."""
    parser = Parser(session, lingual=True, structural=True)
    return parser
def test_cand_gen(caplog):
    """Test extracting candidates from mentions from documents."""
    caplog.set_level(logging.INFO)

    if platform == "darwin":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        logger.info("Using two cores.")
        PARALLEL = 2  # Travis only gives 2 cores

    def do_nothing_matcher(fig):
        return True

    max_docs = 10
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 5548
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    with pytest.raises(ValueError):
        MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    with pytest.raises(ValueError):
        # Spaces are well-formed here (temp_ngrams, not a matcher) so the
        # error can only come from the matcher-arity mismatch.
        MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 234
    assert session.query(Volt).count() == 107
    assert session.query(Temp).count() == 136
    assert session.query(Fig).count() == 223
    part = session.query(Part).order_by(Part.id).all()[0]
    volt = session.query(Volt).order_by(Volt.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Volt: {volt.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    with pytest.raises(ValueError):
        CandidateExtractor(
            session,
            [PartTemp, PartVolt],
            throttlers=[
                temp_throttler,
                volt_throttler,
                volt_throttler,
            ],  # Fail, mismatched arity
        )

    with pytest.raises(ValueError):
        CandidateExtractor(
            session,
            [PartTemp],  # Fail, mismatched arity
            throttlers=[temp_throttler, volt_throttler],
        )

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt]
    )  # Pass, no throttler

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 4141
    assert session.query(PartVolt).count() == 3610
    assert session.query(Candidate).count() == 7751
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(PartVolt).count() == 0

    # Test with None in throttlers in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, None]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    assert session.query(PartTemp).count() == 3879
    assert session.query(PartVolt).count() == 3610
    assert session.query(Candidate).count() == 7489
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 3879
    assert session.query(PartVolt).count() == 3266
    assert session.query(Candidate).count() == 7145
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].volts) == 33
    assert len(docs[0].temps) == 24

    # Test that deletion of a Candidate does not delete the Mention
    session.query(PartTemp).delete(synchronize_session="fetch")
    assert session.query(PartTemp).count() == 0
    assert session.query(Temp).count() == 136
    assert session.query(Part).count() == 234

    # Test deletion of Candidate if Mention is deleted
    assert session.query(PartVolt).count() == 3266
    assert session.query(Volt).count() == 107
    session.query(Volt).delete(synchronize_session="fetch")
    assert session.query(Volt).count() == 0
    assert session.query(PartVolt).count() == 0
def _e2e_entity_f1(true_pred, gold_file, test_docs, parts_by_doc):
    """Score ``true_pred`` at the entity level and log precision, recall and F1.

    Matching against ``gold_file`` is delegated to ``entity_level_f1`` for
    ``ATTRIBUTE``. Any metric whose denominator is zero is reported as ``nan``,
    so range assertions on the returned F1 fail instead of dividing by zero.

    Returns:
        The entity-level F1 score (possibly ``nan``).
    """
    (TP, FP, FN) = entity_level_f1(
        true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc
    )
    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")
    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")
    return f1


def test_e2e():
    """Run an end-to-end test on documents of the hardware domain."""
    PARALLEL = 4

    max_docs = 12

    fonduer.init_logging(
        log_dir="log_folder",
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = fonduer.Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Per-document structure counts, in sorted document order. Comparing whole
    # lists makes pytest report every mismatching index at once.
    assert [len(doc.sentences) for doc in docs] == [
        799, 663, 784, 661, 513, 700, 528, 161, 228, 511, 331, 528,
    ]
    # Check table numbers
    assert [len(doc.tables) for doc in docs] == [
        9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9,
    ]
    # Check figure numbers
    assert [len(doc.figures) for doc in docs] == [
        32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27,
    ]
    # Check caption numbers
    assert [len(doc.captions) for doc in docs] == [0] * ld

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    for i, doc in enumerate(sorted(docs, key=lambda d: d.name)):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 138
    assert session.query(Volt).count() == 140
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299

    doc_112823 = session.query(Document).filter(Document.name == "112823").first()
    assert len(mention_extractor.get_mentions(docs=[doc_112823])[0]) == 70

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    # Use a distinct loop name so the corpus list ``docs`` is not shadowed.
    for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3493
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 416
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3493
    assert len(candidate_extractor.get_candidates(docs=[doc_112823])[0]) == 1432

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    assert session.query(FeatureKey).count() == 1260

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NNP NN IN]"])
    assert session.query(FeatureKey).count() == 1259

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(
        session.query(FeatureKey)
        .filter(FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]")
        .one()
        .candidate_classes
    ) == {"part_temp", "part_volt"}
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]"
    ).one().candidate_classes == ["part_temp"]
    assert session.query(FeatureKey).count() == 1259

    # Inserting the removed key
    featurizer.upsert_keys(
        ["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp, PartVolt]
    )
    assert set(
        session.query(FeatureKey)
        .filter(FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]")
        .one()
        .candidate_classes
    ) == {"part_temp", "part_volt"}
    assert session.query(FeatureKey).count() == 1259

    # Removing the key again
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1258
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6478
    assert session.query(FeatureKey).count() == 4538
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3493, 4538)
    assert F_train[1].shape == (2985, 4538)
    assert len(featurizer.get_keys()) == 4538

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6692
    assert session.query(FeatureKey).count() == 4538
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (61, 4538)
    assert F_dev[1].shape == (153, 4538)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8252
    assert session.query(FeatureKey).count() == 4538
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (416, 4538)
    assert F_test[1].shape == (1144, 4538)

    gold_file = "tests/data/hardware_tutorial_gold.csv"

    labeler = Labeler(session, [PartTemp, PartVolt])

    labeler.apply(
        docs=last_docs,
        lfs=[[gold], [gold]],
        table=GoldLabel,
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(GoldLabel).count() == 8252

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]
    # A flat LF list (not one list per candidate class) must be rejected.
    with pytest.raises(ValueError):
        labeler.apply(split=0, lfs=stg_temp_lfs, train=True, parallelism=PARALLEL)

    labeler.apply(
        docs=train_docs,
        lfs=[stg_temp_lfs, ce_v_max_lfs],
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    # Test Dropping LabelerKey
    labeler.drop_keys(["LF_storage_row"])
    assert len(labeler.get_keys()) == 8

    # Test Upserting LabelerKey
    labeler.upsert_keys(["LF_storage_row"])
    assert "LF_storage_row" in [label.name for label in labeler.get_keys()]

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3493, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3493, 1)

    label_model = LabelModel()
    label_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    # Collect word counter
    word_counter = collect_word_counter(train_cands)

    emmental.init(fonduer.Meta.log_path)

    # Training config
    config = {
        "meta_config": {"verbose": False},
        "model_config": {"model_path": None, "device": 0, "dataparallel": False},
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {"lr": 0.001, "l2": 0.0},
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {f"{ATTRIBUTE}/{ATTRIBUTE}/train/loss": "min"},
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }
    emmental.Meta.update_config(config=config)

    # Generate word embedding module
    arity = 2

    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(
        word_counter=word_counter, word_dim=300, specials=specials
    )

    # Train only on candidates whose marginals are not (near-)uniform.
    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    tasks = create_task(
        ATTRIBUTE, 2, F_train[0].shape[1], 2, emb_layer, model="LogisticRegression"
    )

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE, test_cands[0], F_test[0], emb_layer.word2id, 2
        ),
        split="test",
        batch_size=100,
        shuffle=False,
    )

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.6)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    f1 = _e2e_entity_f1(true_pred, gold_file, test_docs, parts_by_doc)
    assert f1 < 0.7 and f1 > 0.3

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0, lfs=[stg_temp_lfs_2, ce_v_max_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 16)

    label_model = LabelModel()
    label_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(
        ATTRIBUTE, 2, F_train[0].shape[1], 2, emb_layer, model="LogisticRegression"
    )

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    f1 = _e2e_entity_f1(true_pred, gold_file, test_docs, parts_by_doc)
    assert f1 > 0.7

    # Testing LSTM
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(
        ATTRIBUTE, 2, F_train[0].shape[1], 2, emb_layer, model="LSTM"
    )

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    f1 = _e2e_entity_f1(true_pred, gold_file, test_docs, parts_by_doc)
    assert f1 > 0.7
def test_incremental(caplog):
    """Run an end-to-end test on incremental additions.

    Parses one document and runs mention extraction, candidate extraction,
    featurization and labeling on it. Then parses a second document with
    ``clear=False`` and updates each stage from just the new documents,
    checking that the counts grow to cover both documents.
    """
    caplog.set_level(logging.INFO)

    PARALLEL = 1

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/dtc114w.html"
    pdf_path = "tests/data/pdf/dtc114w.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    docs = corpus_parser.get_documents()
    # Compare against the *last parsed batch*, not get_documents() a second
    # time: comparing get_documents() with itself could never fail. After
    # the first apply(), the last batch is the whole corpus.
    last_docs = corpus_parser.get_last_documents()

    assert len(last_docs) == max_docs
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 70
    assert session.query(Candidate).count() == 70

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 70

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 70
    assert session.query(FeatureKey).count() == 491
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (70, 491)
    assert len(featurizer.get_keys()) == 491

    # Test Dropping FeatureKey
    featurizer.drop_keys(["CORE_e1_LENGTH_1"])
    assert session.query(FeatureKey).count() == 490

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    labeler.apply(split=0, lfs=[stg_temp_lfs], train=True, parallelism=PARALLEL)
    assert session.query(Label).count() == 70

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (70, 5)
    assert len(labeler.get_keys()) == 5

    # Incrementally parse a second document, keeping the first (clear=False).
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser.apply(doc_preprocessor, pdf_path=pdf_path, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    new_docs = corpus_parser.get_last_documents()

    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # Get mentions from just the new docs; counts below include both docs.
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)

    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Just run candidate extraction and assign to split 0
    candidate_extractor.apply(new_docs, split=0, parallelism=PARALLEL, clear=False)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1502

    # Update features
    featurizer.update(new_docs, parallelism=PARALLEL)
    assert session.query(Feature).count() == 1502
    assert session.query(FeatureKey).count() == 2424
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (1502, 2424)
    assert len(featurizer.get_keys()) == 2424

    # Update Labels
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 1502
    assert session.query(LabelKey).count() == 6
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (1502, 6)

    # Test clear
    featurizer.clear(train=True)
    assert session.query(FeatureKey).count() == 0
def test_binary_relation_feature_extraction():
    """Test extracting candidates from mentions from documents.

    After extracting ``PartTemp`` candidates from one document, featurize
    them under several feature-extractor configurations and check that:

    * the default library's FeatureKey count equals the sum of the counts
      produced by the textual, tabular, structural and visual libraries
      individually, and
    * adding a custom extractor that emits one distinct key per candidate
      adds exactly ``n_cands`` keys on top of the default library.
    """
    PARALLEL = 1

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)
    temp_ngrams = MentionNgrams(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert docs[0].name == "112823"
    assert session.query(Part).count() == 58
    assert session.query(Temp).count() == 16
    part = session.query(Part).order_by(Part.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    candidate_extractor = CandidateExtractor(session, [PartTemp])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    n_cands = session.query(PartTemp).count()

    def _count_feature_keys(feature_extractors=None):
        """Featurize split 0 and return how many FeatureKeys were created.

        The keys are cleared again before returning, so each configuration
        is counted from an empty FeatureKey table. When ``feature_extractors``
        is None the kwarg is omitted, so the Featurizer's own default applies.
        """
        kwargs = (
            {}
            if feature_extractors is None
            else {"feature_extractors": feature_extractors}
        )
        featurizer = Featurizer(session, [PartTemp], **kwargs)
        featurizer.apply(split=0, train=True, parallelism=PARALLEL)
        n_keys = session.query(FeatureKey).count()
        featurizer.clear(train=True)
        return n_keys

    # Featurization based on the default feature library
    n_default_feats = _count_feature_keys()

    # Example feature extractor: one distinct feature key per candidate
    def feat_ext(candidates):
        candidates = candidates if isinstance(candidates, list) else [candidates]
        for candidate in candidates:
            yield candidate.id, f"cand_id_{candidate.id}", 1

    # Default feature library plus one customized feature extractor
    n_default_w_customized_features = _count_feature_keys(
        FeatureExtractor(customize_feature_funcs=[feat_ext])
    )

    # Example spurious feature extractor
    def bad_feat_ext(candidates):
        raise RuntimeError()

    # Featurization with a spurious feature extractor
    bad_feature_extractors = FeatureExtractor(customize_feature_funcs=[bad_feat_ext])
    logger.info("Featurizing with a spurious feature extractor...")
    # NOTE(review): this call is not wrapped in pytest.raises, so the test only
    # passes if the featurizer does not propagate bad_feat_ext's RuntimeError.
    # Confirm which behavior is intended.
    _count_feature_keys(bad_feature_extractors)

    # Featurization with each individual feature library, in a fixed order
    n_library_feats = {
        library: _count_feature_keys(FeatureExtractor(features=[library]))
        for library in ("textual", "tabular", "structural", "visual")
    }

    # The default library's key count equals the sum of the four libraries'.
    assert n_default_feats == sum(n_library_feats.values())
    # feat_ext contributes exactly one extra key per candidate.
    assert n_default_w_customized_features == n_default_feats + n_cands