def email_extract(docs, session, email_subclass, parallelism, clear=True):
    """Populate ``email_subclass`` with e-mail mentions found in ``docs``.

    Candidate spans come from a ``MentionEmails`` space and are kept when the
    ``email_mc`` rule accepts them (longest match only).

    Args:
        docs: documents to run extraction over.
        session: database session handed to the ``MentionExtractor``.
        email_subclass: mention subclass that receives the results.
        parallelism: worker count forwarded to ``MentionExtractor.apply``.
        clear: forwarded to ``MentionExtractor.apply``.
    """
    rule = LambdaFunctionMatcher(func=email_mc, longest_match_only=True)
    space = MentionEmails()
    extractor = MentionExtractor(
        session,
        [email_subclass],
        [space],
        [rule],
    )
    extractor.apply(docs, parallelism=parallelism, clear=clear)
def test_multimodal_cand(caplog):
    """Extract one mention class per multimodal context type and check counts.

    Each context type (document, caption, section, table, figure, paragraph,
    sentence, cell) gets its own mention subclass and mention space, paired
    with a ``DoNothingMatcher``, and the resulting per-class counts for
    ``radiology.html`` are compared against fixed expectations.
    """
    caplog.set_level(logging.INFO)
    PARALLEL = 4
    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()
    docs_path = "tests/data/pure_html/radiology.html"

    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    parser = Parser(session, structural=True, lingual=True)
    parser.apply(preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 35
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction: declare the subclasses in a fixed order.
    ms_doc, ms_sec, ms_tab, ms_fig, ms_cell, ms_para, ms_cap, ms_sent = [
        mention_subclass(name)
        for name in (
            "m_doc",
            "m_sec",
            "m_tab",
            "m_fig",
            "m_cell",
            "m_para",
            "m_cap",
            "m_sent",
        )
    ]
    space_doc = MentionDocuments()
    space_sec = MentionSections()
    space_tab = MentionTables()
    space_fig = MentionFigures()
    space_cell = MentionCells()
    space_para = MentionParagraphs()
    space_cap = MentionCaptions()
    space_sent = MentionSentences()

    # (mention subclass, mention space, expected count) in extraction order.
    plan = [
        (ms_doc, space_doc, 1),
        (ms_cap, space_cap, 2),
        (ms_sec, space_sec, 5),
        (ms_tab, space_tab, 2),
        (ms_fig, space_fig, 2),
        (ms_para, space_para, 30),
        (ms_sent, space_sent, 35),
        (ms_cell, space_cell, 21),
    ]
    subclasses = [subclass for subclass, _, _ in plan]
    spaces = [space for _, space, _ in plan]
    # NOTE: one matcher instance is shared by all eight slots.
    matchers = [DoNothingMatcher()] * len(plan)

    mention_extractor = MentionExtractor(
        session, subclasses, spaces, matchers, parallelism=PARALLEL
    )
    mention_extractor.apply(docs)

    for subclass, _, expected in plan:
        assert session.query(subclass).count() == expected
def get_mentions(parser_output, first_time=False):
    """Optionally run mention extraction, then record its results.

    Args:
        parser_output: dict from the parsing stage. It must provide
            ``'session'`` and ``'docs'`` and is updated in place.
        first_time: when True, run a ``MentionExtractor`` over the parsed
            docs (mention list, spaces and matchers come from the
            ``mention_definition``, ``mention_space`` and ``matcher``
            modules) before recording. When False, only record.

    Returns:
        The same ``parser_output`` dict, with ``'mention_count'`` (number of
        ``Mention`` rows) and ``'mention_variable'`` (the mention list) set.
    """
    session = parser_output['session']
    docs = parser_output['docs']

    extracted_mentions = None
    if first_time:
        extracted_mentions = mention_definition.get_mention_list()
        extractor = MentionExtractor(
            session,
            extracted_mentions,
            mention_space.get_mention_spaces(),
            matcher.get_matchers(),
        )
        extractor.apply(docs, parallelism=config.PARALLEL)

    parser_output['mention_count'] = session.query(Mention).count()
    # On the non-extracting path the mention list is fetched only after counting.
    parser_output['mention_variable'] = (
        extracted_mentions
        if first_time
        else mention_definition.get_mention_list()
    )
    return parser_output
def phone_extract(docs, session, phone_subclass, parallelism, clear=True):
    """Populate ``phone_subclass`` with phone-number mentions from ``docs``.

    Spans from a ``MentionPhoneNumber`` space are matched by the ``Union`` of
    a ``regexMatch`` rule and a ``matcher_number_phone`` rule.

    Args:
        docs: documents to run extraction over.
        session: database session handed to the ``MentionExtractor``.
        phone_subclass: mention subclass that receives the results.
        parallelism: worker count forwarded to ``MentionExtractor.apply``.
        clear: forwarded to ``MentionExtractor.apply``.
    """
    number_rule = LambdaFunctionMatcher(func=matcher_number_phone)
    regex_rule = LambdaFunctionMatcher(func=regexMatch)
    either_rule = Union(regex_rule, number_rule)
    phone_space = MentionPhoneNumber()

    extractor = MentionExtractor(
        session, [phone_subclass], [phone_space], [either_rule]
    )
    extractor.apply(docs, parallelism=parallelism, clear=clear)
def address_extract(docs, session, address_subclass, parallelism, clear=True):
    """Populate ``address_subclass`` with address mentions from ``docs``.

    Each sentence (``MentionSentences``) is accepted when it satisfies at least
    one of the province / geographic-term / address-prefix rules, and also the
    number-and-geographic-term rule and the ignore-word rule.

    Args:
        docs: documents to run extraction over.
        session: database session handed to the ``MentionExtractor``.
        address_subclass: mention subclass that receives the results.
        parallelism: worker count forwarded to ``MentionExtractor.apply``.
        clear: forwarded to ``MentionExtractor.apply``.
    """
    (
        province_rule,
        geo_term_rule,
        prefix_rule,
        shape_rule,
        no_ignored_words_rule,
    ) = (
        LambdaFunctionMatcher(func=predicate)
        for predicate in (
            has_province_address,
            has_geographic_term_address,
            address_prefix,
            is_collection_of_number_and_geographical_term_and_provinces_name_address,
            hasnt_ignor_words,
        )
    )

    address_matcher = Intersect(
        Union(province_rule, geo_term_rule, prefix_rule),
        shape_rule,
        no_ignored_words_rule,
    )
    sentence_space = MentionSentences()

    extractor = MentionExtractor(
        session, [address_subclass], [sentence_space], [address_matcher]
    )
    extractor.apply(docs, parallelism=parallelism, clear=clear)
def test_visualizer(caplog):
    """Unit test of visualizer using the md document."""
    # Imported locally (hence the noqa) rather than at module top.
    from fonduer.utils.visualizer import Visualizer  # noqa

    caplog.set_level(logging.INFO)
    session = Meta.init("postgresql://localhost:5432/" + ATTRIBUTE).Session()

    PARALLEL = 1
    max_docs = 1
    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/md.pdf"

    # Preprocessor for the Docs
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    # Create a Parser and parse the md document
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    # Grab the md document
    doc = session.query(Document).order_by(Document.name).all()[0]
    assert doc.name == "md"

    organization_ngrams = MentionNgrams(n_max=1)
    Org = mention_subclass("Org")
    organization_matcher = OrganizationMatcher()
    mention_extractor = MentionExtractor(
        session, [Org], [organization_ngrams], [organization_matcher]
    )
    mention_extractor.apply([doc], parallelism=PARALLEL)

    Organization = candidate_subclass("Organization", [Org])
    candidate_extractor = CandidateExtractor(session, [Organization])
    candidate_extractor.apply([doc], split=0, parallelism=PARALLEL)

    cands = session.query(Organization).filter(Organization.split == 0).all()
    # Fail with a readable message instead of an IndexError on cands[0].
    assert cands, "expected at least one Organization candidate in md"

    # Test visualizer. Note the visualizer is given the PDF directory, while
    # the parser above was given the md.pdf file itself.
    pdf_dir = "tests/data/pdf_simple"
    vis = Visualizer(pdf_dir)
    vis.display_candidates([cands[0]])
def birthday_extract(docs, session, birthday_subclass, parallelism, clear=True):
    """Populate ``birthday_subclass`` with birthday mentions from ``docs``.

    Date spans from ``MentionDates`` are kept only when both the
    ``filter_birthday`` and ``birthday_conditions`` rules accept them (each
    rule is configured with ``longest_match_only=True``).

    Args:
        docs: documents to run extraction over.
        session: database session handed to the ``MentionExtractor``.
        birthday_subclass: mention subclass that receives the results.
        parallelism: worker count forwarded to ``MentionExtractor.apply``.
        clear: forwarded to ``MentionExtractor.apply``.
    """
    filter_rule, conditions_rule = (
        LambdaFunctionMatcher(func=predicate, longest_match_only=True)
        for predicate in (filter_birthday, birthday_conditions)
    )
    both_rules = Intersect(filter_rule, conditions_rule)
    date_space = MentionDates()

    extractor = MentionExtractor(
        session, [birthday_subclass], [date_space], [both_rules]
    )
    extractor.apply(docs, parallelism=parallelism, clear=clear)
def test_ngrams(caplog):
    """Check that MentionNgrams honours its n-gram length limits.

    Person mentions are extracted twice from ``lincoln_short.html``: first
    with n-grams up to length 3, then with unigrams excluded (``n_min=2``).
    After each pass the counts of one-word and longer-than-three-word
    mentions are checked.
    """
    caplog.set_level(logging.INFO)
    PARALLEL = 4
    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()
    html_path = "tests/data/pure_html/lincoln_short.html"

    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor(html_path, max_docs=max_docs)
    parser = Parser(session, structural=True, lingual=True)
    parser.apply(preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 503
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    Person = mention_subclass("Person")
    up_to_trigrams = MentionNgrams(n_max=3)
    person_matcher = PersonMatcher()

    def run_extraction(ngram_space):
        # Extract Person mentions from docs with the given n-gram space.
        extractor = MentionExtractor(
            session, [Person], [ngram_space], [person_matcher]
        )
        extractor.apply(docs, parallelism=PARALLEL)

    run_extraction(up_to_trigrams)
    assert session.query(Person).count() == 118
    people = session.query(Person).all()
    assert sum(1 for p in people if p.context.get_num_words() == 1) == 49
    assert sum(1 for p in people if p.context.get_num_words() > 3) == 0

    # Test for unigram exclusion
    run_extraction(MentionNgrams(n_min=2, n_max=3))
    assert session.query(Person).count() == 69
    people = session.query(Person).all()
    assert sum(1 for p in people if p.context.get_num_words() == 1) == 0
    assert sum(1 for p in people if p.context.get_num_words() > 3) == 0
def name_extract(docs, session, name_subclass, parallelism, clear=True):
    """Populate ``name_subclass`` with person-name mentions from ``docs``.

    A span from ``MentionName`` is accepted when ``check_name`` holds AND at
    least one of the following holds:
      * it is well-formed (length, position and capitalisation rules) and
        matches ``last_name``;
      * it is well-formed and matches ``name_common``;
      * it matches ``prefix_name``.

    Args:
        docs: documents to run extraction over.
        session: database session handed to the ``MentionExtractor``.
        name_subclass: mention subclass that receives the results.
        parallelism: worker count forwarded to ``MentionExtractor.apply``.
        clear: forwarded to ``MentionExtractor.apply``.
    """
    (
        length_rule,
        position_rule,
        capitalize_rule,
        last_name_rule,
        common_name_rule,
        check_rule,
        prefix_rule,
    ) = (
        LambdaFunctionMatcher(func=predicate)
        for predicate in (
            length_name,
            position_name,
            capitalize_name,
            last_name,
            name_common,
            check_name,
            prefix_name,
        )
    )

    # One shared "well-formed" matcher feeds both surname branches.
    well_formed = Intersect(length_rule, position_rule, capitalize_rule)
    name_matcher = Intersect(
        Union(
            Intersect(last_name_rule, well_formed),
            Intersect(common_name_rule, well_formed),
            prefix_rule,
        ),
        check_rule,
    )
    name_space = MentionName()

    extractor = MentionExtractor(
        session, [name_subclass], [name_space], [name_matcher]
    )
    extractor.apply(docs, parallelism=parallelism, clear=clear)
def test_row_col_ngram_extraction(caplog):
    """Test row/column ngrams of mentions inside and outside tables.

    For a mention whose sentence is not tabular, ``get_row_ngrams`` and
    ``get_col_ngrams`` must each yield exactly one ``None``. For a tabular
    mention, neither may yield ``None``. The checks run inside the matcher,
    once for every candidate span the extractor evaluates.
    """
    caplog.set_level(logging.INFO)
    PARALLEL = 1
    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()
    docs_path = "tests/data/pure_html/lincoln_short.html"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    place_ngrams = MentionNgramsTemp(n_max=4)
    Place = mention_subclass("Place")

    def get_row_and_column_ngrams(mention):
        # The assertions below are the real test. NOTE(review): PARALLEL = 1
        # is presumably what lets an AssertionError raised here reach pytest
        # (no worker process) -- confirm.
        row_ngrams = list(get_row_ngrams(mention))
        col_ngrams = list(get_col_ngrams(mention))
        if not mention.sentence.is_tabular():
            assert len(row_ngrams) == 1 and row_ngrams[0] is None
            assert len(col_ngrams) == 1 and col_ngrams[0] is None
        else:
            assert not any(x is None for x in row_ngrams)
            assert not any(x is None for x in col_ngrams)
        return "birth_place" in row_ngrams

    birthplace_matcher = LambdaFunctionMatcher(func=get_row_and_column_ngrams)
    mention_extractor = MentionExtractor(
        session, [Place], [place_ngrams], [birthplace_matcher]
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)
def test_cand_gen(caplog):
    """Test extracting candidates from mentions from documents."""
    caplog.set_level(logging.INFO)

    if platform == "darwin":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        logger.info("Using two cores.")
        PARALLEL = 2  # Travis only gives 2 cores

    def do_nothing_matcher(fig):
        return True

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    # Arity mismatches must raise ValueError. Every list holds only objects of
    # the right kind, so the length mismatch is the only thing under test.
    with pytest.raises(ValueError):
        MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    with pytest.raises(ValueError):
        MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 70
    assert session.query(Volt).count() == 33
    assert session.query(Temp).count() == 23
    assert session.query(Fig).count() == 31
    part = session.query(Part).order_by(Part.id).all()[0]
    volt = session.query(Volt).order_by(Volt.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Volt: {volt.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    with pytest.raises(ValueError):
        CandidateExtractor(
            session,
            [PartTemp, PartVolt],
            throttlers=[
                temp_throttler,
                volt_throttler,
                volt_throttler,
            ],  # Fail, mismatched arity
        )

    with pytest.raises(ValueError):
        CandidateExtractor(
            session,
            [PartTemp],  # Fail, mismatched arity
            throttlers=[temp_throttler, volt_throttler],
        )

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt]
    )  # Pass, no throttler
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    assert session.query(PartTemp).count() == 1610
    assert session.query(PartVolt).count() == 2310
    assert session.query(Candidate).count() == 3920
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(PartVolt).count() == 0

    # Test with None in throttlers in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, None]
    )
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    assert session.query(PartTemp).count() == 1432
    assert session.query(PartVolt).count() == 2310
    assert session.query(Candidate).count() == 3742
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    assert session.query(PartTemp).count() == 1432
    assert session.query(PartVolt).count() == 1993
    assert session.query(Candidate).count() == 3425
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].volts) == 33
    assert len(docs[0].temps) == 23

    # Test that deletion of a Candidate does not delete the Mention
    session.query(PartTemp).delete(synchronize_session="fetch")
    assert session.query(PartTemp).count() == 0
    assert session.query(Temp).count() == 23
    assert session.query(Part).count() == 70

    # Test deletion of Candidate if Mention is deleted
    assert session.query(PartVolt).count() == 1993
    assert session.query(Volt).count() == 33
    session.query(Volt).delete(synchronize_session="fetch")
    assert session.query(Volt).count() == 0
    assert session.query(PartVolt).count() == 0
def test_e2e():
    """Run an end-to-end test on documents of the hardware domain.

    Parses up to ``max_docs`` HTML/PDF documents and extracts Part/Temp/Volt
    mentions and PartTemp/PartVolt candidates. It then featurizes and labels
    the candidates and trains LogisticRegression and LSTM models with
    Emmental. The entity-level F1 of PartTemp predictions on the test split
    is checked after each training run.
    """
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 12

    fonduer.init_logging(
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = fonduer.Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Per-document structure counts, in sorted document order.
    assert [len(doc.sentences) for doc in docs] == [
        799, 663, 784, 661, 513, 700, 528, 161, 228, 511, 331, 528,
    ]
    # Check table numbers
    assert [len(doc.tables) for doc in docs] == [
        9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9,
    ]
    # Check figure numbers
    assert [len(doc.figures) for doc in docs] == [
        32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27,
    ]
    # Check caption numbers
    assert [len(doc.captions) for doc in docs] == [0] * 12

    # Split by document name: the first half goes to train, up to 75% to
    # dev, and the rest to test.
    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    for i, doc in enumerate(sorted(docs, key=lambda d: d.name)):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 138
    assert session.query(Volt).count() == 140
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299

    # Document used for the per-document mention/candidate checks.
    doc_112823 = session.query(Document).filter(Document.name == "112823").first()
    assert len(mention_extractor.get_mentions(docs=[doc_112823])[0]) == 70

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    # Use a distinct loop name so the full ``docs`` list is not clobbered.
    for split, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=split, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3493
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 416
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3493
    assert len(candidate_extractor.get_candidates(docs=[doc_112823])[0]) == 1432

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    assert session.query(FeatureKey).count() == 1260

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NNP NN IN]"])
    assert session.query(FeatureKey).count() == 1259

    lemma_key = "DDL_e1_LEMMA_SEQ_[bc182]"

    def _lemma_key_classes():
        # candidate_classes stored on the FeatureKey row named lemma_key.
        return (
            session.query(FeatureKey)
            .filter(FeatureKey.name == lemma_key)
            .one()
            .candidate_classes
        )

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(_lemma_key_classes()) == {"part_temp", "part_volt"}
    featurizer.drop_keys([lemma_key], candidate_classes=[PartVolt])
    assert _lemma_key_classes() == ["part_temp"]
    assert session.query(FeatureKey).count() == 1259

    # Inserting the removed key
    featurizer.upsert_keys([lemma_key], candidate_classes=[PartTemp, PartVolt])
    assert set(_lemma_key_classes()) == {"part_temp", "part_volt"}
    assert session.query(FeatureKey).count() == 1259
    # Removing the key again
    featurizer.drop_keys([lemma_key], candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys([lemma_key], candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1258
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6478
    assert session.query(FeatureKey).count() == 4538
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3493, 4538)
    assert F_train[1].shape == (2985, 4538)
    assert len(featurizer.get_keys()) == 4538

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6692
    assert session.query(FeatureKey).count() == 4538
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (61, 4538)
    assert F_dev[1].shape == (153, 4538)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8252
    assert session.query(FeatureKey).count() == 4538
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (416, 4538)
    assert F_test[1].shape == (1144, 4538)

    gold_file = "tests/data/hardware_tutorial_gold.csv"

    labeler = Labeler(session, [PartTemp, PartVolt])

    labeler.apply(
        docs=last_docs,
        lfs=[[gold], [gold]],
        table=GoldLabel,
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(GoldLabel).count() == 8252

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]
    # A flat LF list for a two-class labeler must be rejected.
    with pytest.raises(ValueError):
        labeler.apply(split=0, lfs=stg_temp_lfs, train=True, parallelism=PARALLEL)

    labeler.apply(
        docs=train_docs,
        lfs=[stg_temp_lfs, ce_v_max_lfs],
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    # Test Dropping LabelerKey
    labeler.drop_keys(["LF_storage_row"])
    assert len(labeler.get_keys()) == 8

    # Test Upserting LabelerKey
    labeler.upsert_keys(["LF_storage_row"])
    assert "LF_storage_row" in [label.name for label in labeler.get_keys()]

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3493, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3493, 1)

    label_model = LabelModel()
    label_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    # Collect word counter
    word_counter = collect_word_counter(train_cands)

    emmental.init(fonduer.Meta.log_path)

    # Training config
    config = {
        "meta_config": {"verbose": False},
        "model_config": {"model_path": None, "device": 0, "dataparallel": False},
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {"lr": 0.001, "l2": 0.0},
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {f"{ATTRIBUTE}/{ATTRIBUTE}/train/loss": "min"},
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }
    emmental.Meta.update_config(config=config)

    # Generate word embedding module
    arity = 2

    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(
        word_counter=word_counter, word_dim=300, specials=specials
    )

    def _reset_emmental():
        # Start a fresh Emmental session with the shared training config.
        emmental.Meta.reset()
        emmental.init(fonduer.Meta.log_path)
        emmental.Meta.update_config(config=config)

    def _train(model_type, dataloaders):
        # Build a new EmmentalModel with ``model_type`` and train it.
        tasks = create_task(
            ATTRIBUTE, 2, F_train[0].shape[1], 2, emb_layer, model=model_type
        )
        model = EmmentalModel(name=f"{ATTRIBUTE}_task")
        for task in tasks:
            model.add_task(task)
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders)
        return model

    def _test_f1(model, threshold):
        # Predict on the test split, keep candidates whose TRUE-class prob
        # exceeds ``threshold``, log entity-level prec/rec/F1, and return F1
        # (NaN when undefined).
        test_preds = model.predict(test_dataloader, return_preds=True)
        positive = np.where(
            np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > threshold
        )
        true_pred = [test_cands[0][idx] for idx in positive[0]]
        (TP, FP, FN) = entity_level_f1(
            true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc
        )
        tp_len, fp_len, fn_len = len(TP), len(FP), len(FN)
        prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
        rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
        f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")
        logger.info(f"prec: {prec}")
        logger.info(f"rec: {rec}")
        logger.info(f"f1: {f1}")
        return f1

    # Only train on candidates whose marginals are not (near-)uniform.
    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    model = _train("LogisticRegression", [train_dataloader])

    # NOTE(review): the bare ``2`` is passed where the train set passes
    # marginals -- presumably a label placeholder for unlabeled test data.
    test_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE, test_cands[0], F_test[0], emb_layer.word2id, 2
        ),
        split="test",
        batch_size=100,
        shuffle=False,
    )

    # Repository test fixture (trusted input for pickle).
    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    f1 = _test_f1(model, 0.6)
    assert f1 < 0.7 and f1 > 0.3

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0, lfs=[stg_temp_lfs_2, ce_v_max_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 16)

    label_model = LabelModel()
    label_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    valid_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            np.argmax(train_marginals, axis=1),
            train_idxs,
        ),
        split="valid",
        batch_size=100,
        shuffle=False,
    )

    _reset_emmental()
    model = _train("LogisticRegression", [train_dataloader, valid_dataloader])
    f1 = _test_f1(model, 0.7)
    assert f1 > 0.7

    # Testing LSTM
    _reset_emmental()
    model = _train("LSTM", [train_dataloader])
    f1 = _test_f1(model, 0.7)
    assert f1 > 0.7
def test_cand_gen_cascading_delete(caplog):
    """Verify that candidate deletion cascades through the class hierarchy.

    After parsing one document and extracting ``Part``/``Temp`` mentions and
    ``PartTemp`` candidates, this checks that:

    * deleting a row through the parent ``Candidate`` query also removes the
      corresponding ``PartTemp`` row, and
    * clearing the mentions removes every mention and every candidate.
    """
    caplog.set_level(logging.INFO)

    # macOS runs on a single core; elsewhere use two (Travis only gives 2 cores).
    single_core = platform == "darwin"
    logger.info("Using single core." if single_core else "Using two cores.")
    workers = 1 if single_core else 2

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    def n_rows(table):
        """Return the number of rows currently stored for ``table``."""
        return session.query(table).count()

    # Parsing
    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor("tests/data/html/", max_docs=max_docs)
    parser = Parser(
        session,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path="tests/data/pdf/",
    )
    parser.apply(preprocessor, parallelism=workers)
    assert n_rows(Document) == max_docs
    assert n_rows(Sentence) == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_space = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_space = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp],
        [part_space, temp_space],
        [part_matcher, temp_matcher],
    )
    mention_extractor.clear_all()
    mention_extractor.apply(docs, parallelism=workers)

    for table, expected in ((Mention, 93), (Part, 70), (Temp, 23)):
        assert n_rows(table) == expected

    first_part = session.query(Part).order_by(Part.id).first()
    first_temp = session.query(Temp).order_by(Temp.id).first()
    logger.info(f"Part: {first_part.context}")
    logger.info(f"Temp: {first_temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )
    candidate_extractor.apply(docs, split=0, parallelism=workers)

    assert n_rows(PartTemp) == 1432
    assert n_rows(Candidate) == 1432

    doc = docs[0]
    assert doc.name == "112823"
    assert len(doc.parts) == 70
    assert len(doc.temps) == 23

    # Deleting through the parent Candidate table must cascade to PartTemp.
    victim = session.query(Candidate).first()
    session.query(Candidate).filter_by(id=victim.id).delete(
        synchronize_session="fetch"
    )
    assert n_rows(Candidate) == 1431
    assert n_rows(PartTemp) == 1431

    # Clearing mentions must also remove every dependent candidate.
    mention_extractor.clear()
    for table in (Mention, Part, Temp, PartTemp, Candidate):
        assert n_rows(table) == 0
def main(
    conn_string,
    gain=False,
    current=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    gpu=None,
    parallel=8,
    log_dir="logs",
    verbose=False,
):
    """Run the gain / supply-current extraction pipeline end to end.

    The pipeline has these stages:

    1. Parse the dataset.
    2. Extract mentions and candidates for the enabled relations.
    3. Featurize the candidates and apply labeling functions.
    4. For each relation, fit a generative model and then a discriminative
       model, and score it on the test split.
    5. Dump per-split probabilities to CSV files.

    Args:
        conn_string: Database connection string passed to ``Meta.init``.
        gain: Enable the "gain" relation.
        current: Enable the "current" (supply current) relation.
        max_docs: Maximum number of documents handed to ``parse_dataset``.
        parse: Forwarded to ``parse_dataset`` as its ``first_time`` flag.
        first_time: Run mention/candidate extraction, featurization and
            labeling, and pickle the feature matrices. When false, the
            pickles from a previous run are loaded instead.
        re_label: When not ``first_time``, re-apply the labeling functions.
        gpu: If truthy, exported as ``CUDA_VISIBLE_DEVICES`` and passed to
            ``discriminative_model``.
        parallel: Parallelism for parsing, extraction, featurization and
            labeling.
        log_dir: Log directory, relative to this script's directory.
        verbose: Log at INFO instead of WARNING.
    """
    # Setup initial configuration
    if gpu:
        # Environment values must be strings; accept an int device id too.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

    if not log_dir:
        log_dir = "logs"

    level = logging.INFO if verbose else logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    def save_pickle(obj, name):
        """Pickle ``obj`` into ``name`` next to this script, closing the file."""
        with open(os.path.join(dirname, name), "wb") as f:
            pickle.dump(obj, f)

    def load_pickle(name):
        """Load a pickle written by an earlier ``first_time`` run.

        pickle can execute arbitrary code on load; only load files that this
        script produced itself.
        """
        with open(os.path.join(dirname, name), "rb") as f:
            return pickle.load(f)

    rel_list = []
    if gain:
        rel_list.append("gain")
    if current:
        rel_list.append("current")

    logger.info("=" * 30)
    logger.info(f"Running with parallel: {parallel}, max_docs: {max_docs}")

    session = Meta.init(conn_string).Session()

    # Parsing
    start = timer()
    logger.info("Starting parsing...")
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=parse, parallel=parallel, max_docs=max_docs
    )
    logger.debug("Done")
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of Documents: {len(docs)}")
    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")

    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    if gain:
        Gain = mention_subclass("Gain")
        gain_matcher = get_gain_matcher()
        gain_ngrams = MentionNgrams(n_max=2)
        mentions.append(Gain)
        ngrams.append(gain_ngrams)
        matchers.append(gain_matcher)

    if current:
        Current = mention_subclass("SupplyCurrent")
        current_matcher = get_supply_current_matcher()
        current_ngrams = MentionNgramsCurrent(n_max=3)
        mentions.append(Current)
        ngrams.append(current_ngrams)
        matchers.append(current_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")

    if gain:
        logger.info(f"Total Gain: {session.query(Gain).count()}")

    if current:
        logger.info(f"Total Current: {session.query(Current).count()}")

    cand_classes = []
    if gain:
        GainCand = candidate_subclass("GainCand", [Gain])
        cand_classes.append(GainCand)
    if current:
        CurrentCand = candidate_subclass("CurrentCand", [Current])
        cand_classes.append(CurrentCand)

    candidate_extractor = CandidateExtractor(session, cand_classes)

    if first_time:
        # Split ids: 0 = train, 1 = dev, 2 = test. A distinct loop variable
        # keeps the full ``docs`` list intact.
        for split, split_docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(split_docs, split=split, parallelism=parallel)

    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)

    # get_candidates returns one list per candidate class. Sum over however
    # many relations are enabled; indexing [1] fails when only one is.
    logger.info(f"Total train candidate: {sum(len(c) for c in train_cands)}")
    logger.info(f"Total dev candidate: {sum(len(c) for c in dev_cands)}")
    logger.info(f"Total test candidate: {sum(len(c) for c in test_cands)}")
    logger.info("Done w/ candidate extraction.")
    end = timer()
    logger.warning(f"CE Time (min): {((end - start) / 60.0):.1f}")

    # First, check total recall
    # result = entity_level_scores(dev_cands[0], corpus=dev_docs)
    # logger.info(f"Gain Total Dev Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")
    # result = entity_level_scores(test_cands[0], corpus=test_docs)
    # logger.info(f"Gain Total Test Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")
    #
    # result = entity_level_scores(dev_cands[1], corpus=dev_docs, is_gain=False)
    # logger.info(f"Current Total Dev Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")
    # result = entity_level_scores(test_cands[1], corpus=test_docs, is_gain=False)
    # logger.info(f"Current Test Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")

    start = timer()
    featurizer = Featurizer(session, cand_classes)
    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True, parallelism=parallel)
        featurizer.apply(split=1, parallelism=parallel)
        featurizer.apply(split=2, parallelism=parallel)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    # Serialize feature matrices on first run
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(f"Featurization Time (min): {((end - start) / 60.0):.1f}")
        save_pickle(F_train, "F_train.pkl")
        save_pickle(F_dev, "F_dev.pkl")
        save_pickle(F_test, "F_test.pkl")
    else:
        F_train = load_pickle("F_train.pkl")
        F_dev = load_pickle("F_dev.pkl")
        F_test = load_pickle("F_test.pkl")
    logger.info("Done.")

    start = timer()
    logger.info("Labeling training data...")
    labeler = Labeler(session, cand_classes)
    lfs = []
    if gain:
        lfs.append(gain_lfs)

    if current:
        lfs.append(current_lfs)

    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
    elif re_label:
        logger.info("Re-applying LFs...")
        labeler.update(split=0, lfs=lfs, parallelism=parallel)

    logger.info("Done...")
    logger.info("Getting label matrices...")
    L_train = labeler.get_label_matrices(train_cands)
    logger.info("Done...")
    end = timer()
    logger.warning(f"Weak Supervision Time (min): {((end - start) / 60.0):.1f}")

    def dev_ground_truth(cands, is_gain):
        """Label each dev candidate TRUE if any of its entities is a gold entity."""
        gold_entities = get_gold_set(is_gain=is_gain)
        return [
            TRUE
            if any(
                entity in gold_entities
                for entity in cand_to_entity(c, is_gain=is_gain)
            )
            else FALSE
            for c in cands
        ]

    def dump_probs(disc_models, idx, relation, is_gain):
        """Write train/test/dev marginals for ``relation`` to the CSV outputs.

        The train split is written without ``append=True``; test and dev are
        appended. Test and dev probabilities are also dumped to
        ``<relation>_<split>_probs.csv``.
        """
        logger.info("Output CSV files for Opo and Digi-key Analysis.")
        # Dump CSV files for digi-key analysis and Opo comparison
        Y_prob = disc_models.marginals((train_cands[idx], F_train[idx]))
        output_csv(train_cands[idx], Y_prob, is_gain=is_gain)

        for cands, feats, split_name in (
            (test_cands[idx], F_test[idx], "test"),
            (dev_cands[idx], F_dev[idx], "dev"),
        ):
            Y_prob = disc_models.marginals((cands, feats))
            output_csv(cands, Y_prob, is_gain=is_gain, append=True)
            dump_candidates(
                cands, Y_prob, f"{relation}_{split_name}_probs.csv", is_gain=is_gain
            )

    # Reset the timer so the reported time covers classification and dumping only.
    start = timer()

    if gain:
        relation = "gain"
        idx = rel_list.index(relation)

        logger.info("Score Gain.")
        L_dev_gt = dev_ground_truth(dev_cands[idx], is_gain=True)

        marginals = generative_model(L_train[idx])

        disc_models = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals,
            X_dev=(dev_cands[idx], F_dev[idx]),
            Y_dev=L_dev_gt,
            n_epochs=500,
            gpu=gpu,
        )
        best_result, best_b = scoring(
            disc_models, test_cands[idx], test_docs, F_test[idx], num=50
        )

        print_scores(relation, best_result, best_b)

        dump_probs(disc_models, idx, relation, is_gain=True)

    if current:
        relation = "current"
        idx = rel_list.index(relation)

        logger.info("Score Current.")
        L_dev_gt = dev_ground_truth(dev_cands[idx], is_gain=False)

        marginals = generative_model(L_train[idx])

        disc_models = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals,
            X_dev=(dev_cands[idx], F_dev[idx]),
            Y_dev=L_dev_gt,
            n_epochs=100,
            gpu=gpu,
        )
        best_result, best_b = scoring(
            disc_models,
            test_cands[idx],
            test_docs,
            F_test[idx],
            is_gain=False,
            num=50,
        )

        print_scores(relation, best_result, best_b)

        dump_probs(disc_models, idx, relation, is_gain=False)

    end = timer()
    logger.warning(
        f"Classification AND dump data Time (min): {((end - start) / 60.0):.1f}"
    )
def test_incremental(caplog):
    """Run an end-to-end test on incremental additions.

    Parses, extracts, featurizes and labels one document. Then a second
    document is parsed with ``clear=False`` and each stage is updated with it.
    The test checks the accumulated counts and matrix shapes after every step.
    """
    caplog.set_level(logging.INFO)

    PARALLEL = 1

    max_docs = 1

    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/dtc114w.html"
    pdf_path = "tests/data/pdf/dtc114w.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    # get_last_documents() returns only the batch from the most recent apply().
    # After a single parse, that batch is the whole corpus, so both views must
    # agree.
    docs = corpus_parser.get_documents()
    last_docs = corpus_parser.get_last_documents()

    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 70
    assert session.query(Candidate).count() == 70

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 70

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 70
    assert session.query(FeatureKey).count() == 512
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (70, 512)
    assert len(featurizer.get_keys()) == 512

    # Test Dropping FeatureKey
    # NOTE(review): the key count stays at 512, so this key is presumably not
    # among the extracted keys. Confirm this exercises an actual drop.
    featurizer.drop_keys(["CORE_e1_LENGTH_1"])
    assert session.query(FeatureKey).count() == 512

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    labeler.apply(split=0, lfs=[stg_temp_lfs], train=True, parallelism=PARALLEL)
    assert session.query(Label).count() == 70

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (70, 5)
    assert len(labeler.get_keys()) == 5

    # Incrementally add a second document without clearing existing data.
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser.apply(doc_preprocessor, pdf_path=pdf_path, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    new_docs = corpus_parser.get_last_documents()

    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # Get mentions from just the new docs
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)
    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Just run candidate extraction and assign to split 0
    candidate_extractor.apply(new_docs, split=0, parallelism=PARALLEL, clear=False)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1502

    # Update features
    featurizer.update(new_docs, parallelism=PARALLEL)
    assert session.query(Feature).count() == 1502
    assert session.query(FeatureKey).count() == 2573
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (1502, 2573)
    assert len(featurizer.get_keys()) == 2573

    # Update Labels
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 1502
    assert session.query(LabelKey).count() == 6
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (1502, 6)

    # Test clear
    featurizer.clear(train=True)
    assert session.query(FeatureKey).count() == 0
def test_incremental():
    """Run an end-to-end test on incremental additions.

    Parses, extracts, featurizes and labels one document. Then a second
    document is parsed with ``clear=False`` and each stage is updated with it.
    The test also checks three things:

    * re-applying mention and candidate extraction to the same documents
      does not duplicate rows;
    * updated labeling functions replace the previous labels;
    * clearing the featurizer removes all feature keys.
    """
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 1

    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/dtc114w.html"
    pdf_path = "tests/data/pdf/dtc114w.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    # get_last_documents() returns only the batch from the most recent apply().
    # After a single parse, that batch is the whole corpus, so both views must
    # agree.
    docs = corpus_parser.get_documents()
    last_docs = corpus_parser.get_last_documents()

    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 70
    assert session.query(Candidate).count() == 70

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 70

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 70
    assert session.query(FeatureKey).count() == 512
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (70, 512)
    assert len(featurizer.get_keys()) == 512

    # Test Dropping FeatureKey
    # NOTE(review): the key count stays at 512, so this key is presumably not
    # among the extracted keys. Confirm this exercises an actual drop.
    featurizer.drop_keys(["CORE_e1_LENGTH_1"])
    assert session.query(FeatureKey).count() == 512

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    labeler.apply(split=0, lfs=[stg_temp_lfs], train=True, parallelism=PARALLEL)
    assert session.query(Label).count() == 70

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (70, 5)
    assert len(labeler.get_keys()) == 5

    # Incrementally add a second document without clearing existing data.
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser.apply(doc_preprocessor, pdf_path=pdf_path, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    new_docs = corpus_parser.get_last_documents()

    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # Get mentions from just the new docs
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)
    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Test if existing mentions are skipped.
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)
    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Just run candidate extraction and assign to split 0
    candidate_extractor.apply(new_docs, split=0, parallelism=PARALLEL, clear=False)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1502

    # Test if existing candidates are skipped.
    candidate_extractor.apply(new_docs, split=0, parallelism=PARALLEL, clear=False)
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1502

    # Update features
    featurizer.update(new_docs, parallelism=PARALLEL)
    assert session.query(Feature).count() == 1502
    assert session.query(FeatureKey).count() == 2573
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (1502, 2573)
    assert len(featurizer.get_keys()) == 2573

    # Update LF_storage_row. Now it always returns ABSTAIN.
    @labeling_function(name="LF_storage_row")
    def LF_storage_row_updated(c):
        return ABSTAIN

    stg_temp_lfs = [
        LF_storage_row_updated,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    # Update Labels
    labeler.update(docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 1502
    # Only 5 because LF_storage_row doesn't apply to any doc (always ABSTAIN)
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (1502, 5)

    # Test clear
    featurizer.clear(train=True)
    assert session.query(FeatureKey).count() == 0
def main(
    conn_string,
    gain=False,
    current=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    parallel=8,
    log_dir="logs",
    verbose=False,
):
    """Train one multi-task Emmental model for the gain / current relations.

    The pipeline has these stages:

    1. Parse the dataset.
    2. Extract mentions and (sorted) candidates for the enabled relations.
    3. Featurize them.
    4. Build training marginals from the human-annotated dev set.
    5. Train a logistic-regression multi-task model on the dev candidates.
    6. Score each relation on the test split and dump per-split
       probabilities to CSV files.

    Args:
        conn_string: Database connection string passed to ``Meta.init``.
        gain: Enable the "gain" relation.
        current: Enable the "current" (supply current) relation.
        max_docs: Maximum number of documents handed to ``parse_dataset``.
        parse: Forwarded to ``parse_dataset`` as its ``first_time`` flag.
        first_time: Run extraction and featurization, and pickle the feature
            matrices and marginals. When false, the pickles from a previous
            run are loaded instead.
        re_label: Currently unused, because the labeling-function code below
            is commented out. It is kept so existing callers still work.
        parallel: Parallelism for parsing and mention/candidate extraction.
            Featurization always uses 1.
        log_dir: Log directory, relative to this script's directory.
        verbose: Log at INFO instead of WARNING.
    """
    # Setup initial configuration
    if not log_dir:
        log_dir = "logs"

    level = logging.INFO if verbose else logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    def save_pickle(obj, name):
        """Pickle ``obj`` into ``name`` next to this script, closing the file."""
        with open(os.path.join(dirname, name), "wb") as f:
            pickle.dump(obj, f)

    def load_pickle(name):
        """Load a pickle written by an earlier ``first_time`` run.

        pickle can execute arbitrary code on load; only load files that this
        script produced itself.
        """
        with open(os.path.join(dirname, name), "rb") as f:
            return pickle.load(f)

    rel_list = []
    if gain:
        rel_list.append("gain")
    if current:
        rel_list.append("current")

    logger.info("=" * 30)
    logger.info(f"Running with parallel: {parallel}, max_docs: {max_docs}")

    session = Meta.init(conn_string).Session()

    # Parsing
    start = timer()
    logger.info("Starting parsing...")
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=parse, parallel=parallel, max_docs=max_docs
    )
    logger.debug("Done")
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of Documents: {len(docs)}")
    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")

    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    if gain:
        Gain = mention_subclass("Gain")
        gain_matcher = get_gain_matcher()
        gain_ngrams = MentionNgrams(n_max=2)
        mentions.append(Gain)
        ngrams.append(gain_ngrams)
        matchers.append(gain_matcher)

    if current:
        Current = mention_subclass("SupplyCurrent")
        current_matcher = get_supply_current_matcher()
        current_ngrams = MentionNgramsCurrent(n_max=3)
        mentions.append(Current)
        ngrams.append(current_ngrams)
        matchers.append(current_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")

    if gain:
        logger.info(f"Total Gain: {session.query(Gain).count()}")

    if current:
        logger.info(f"Total Current: {session.query(Current).count()}")

    cand_classes = []
    if gain:
        GainCand = candidate_subclass("GainCand", [Gain])
        cand_classes.append(GainCand)
    if current:
        CurrentCand = candidate_subclass("CurrentCand", [Current])
        cand_classes.append(CurrentCand)

    candidate_extractor = CandidateExtractor(session, cand_classes)

    if first_time:
        # Split ids: 0 = train, 1 = dev, 2 = test. A distinct loop variable
        # keeps the full ``docs`` list intact.
        for split, split_docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(split_docs, split=split, parallelism=parallel)

    # These must be sorted for deterministic behavior.
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)

    # get_candidates returns one list per candidate class. Sum over however
    # many relations are enabled; indexing [1] fails when only one is.
    logger.info(f"Total train candidate: {sum(len(c) for c in train_cands)}")
    logger.info(f"Total dev candidate: {sum(len(c) for c in dev_cands)}")
    logger.info(f"Total test candidate: {sum(len(c) for c in test_cands)}")
    logger.info("Done w/ candidate extraction.")
    end = timer()
    logger.warning(f"CE Time (min): {((end - start) / 60.0):.1f}")

    # First, check total recall
    # result = entity_level_scores(
    #     candidates_to_entities(dev_cands[0], is_gain=True),
    #     corpus=dev_docs,
    #     is_gain=True,
    # )
    # logger.info(f"Gain Total Dev Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")
    # result = entity_level_scores(
    #     candidates_to_entities(test_cands[0], is_gain=True),
    #     corpus=test_docs,
    #     is_gain=True,
    # )
    # logger.info(f"Gain Total Test Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")
    #
    # result = entity_level_scores(
    #     candidates_to_entities(dev_cands[1], is_gain=False),
    #     corpus=dev_docs,
    #     is_gain=False,
    # )
    # logger.info(f"Current Total Dev Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")
    # result = entity_level_scores(
    #     candidates_to_entities(test_cands[1], is_gain=False),
    #     corpus=test_docs,
    #     is_gain=False,
    # )
    # logger.info(f"Current Test Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")

    start = timer()
    # Using parallelism = 1 for deterministic behavior.
    featurizer = Featurizer(session, cand_classes, parallelism=1)
    if first_time:
        logger.info("Starting featurizer...")
        # Set feature space based on dev set, which we use for training rather
        # than the large train set.
        featurizer.apply(split=1, train=True)
        featurizer.apply(split=0)
        featurizer.apply(split=2)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    # Serialize feature matrices on first run. They are keyed by relation name,
    # and later runs reload them in ``rel_list`` order.
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        save_pickle(dict(zip(rel_list, F_train)), "F_train_dict.pkl")
        save_pickle(dict(zip(rel_list, F_dev)), "F_dev_dict.pkl")
        save_pickle(dict(zip(rel_list, F_test)), "F_test_dict.pkl")
    else:
        F_train_dict = load_pickle("F_train_dict.pkl")
        F_dev_dict = load_pickle("F_dev_dict.pkl")
        F_test_dict = load_pickle("F_test_dict.pkl")
        F_train = [F_train_dict[relation] for relation in rel_list]
        F_dev = [F_dev_dict[relation] for relation in rel_list]
        F_test = [F_test_dict[relation] for relation in rel_list]
    logger.info("Done.")

    start = timer()
    logger.info("Labeling training data...")
    # labeler = Labeler(session, cand_classes)
    # lfs = []
    # if gain:
    #     lfs.append(gain_lfs)
    #
    # if current:
    #     lfs.append(current_lfs)
    #
    # if first_time:
    #     logger.info("Applying LFs...")
    #     labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
    # elif re_label:
    #     logger.info("Re-applying LFs...")
    #     labeler.update(split=0, lfs=lfs, parallelism=parallel)
    #
    # logger.info("Done...")
    # logger.info("Getting label matrices...")
    # L_train = labeler.get_label_matrices(train_cands)
    # logger.info("Done...")

    if first_time:
        marginals_dict = {}
        for idx, relation in enumerate(rel_list):
            # Manually create marginals from human annotations. Each row is
            # [0.0, 1.0] when any entity of the candidate is in the dev gold
            # set, else [1.0, 0.0].
            is_gain = relation == "gain"
            dev_gold_entities = get_gold_set(is_gain=is_gain)
            marginal = [
                [0.0, 1.0]
                if any(
                    entity in dev_gold_entities
                    for entity in cand_to_entity(c, is_gain=is_gain)
                )
                else [1.0, 0.0]
                for c in dev_cands[idx]
            ]
            marginals_dict[relation] = np.array(marginal)
        save_pickle(marginals_dict, "marginals_dict.pkl")
    else:
        marginals_dict = load_pickle("marginals_dict.pkl")

    marginals = [marginals_dict[relation] for relation in rel_list]

    end = timer()
    logger.warning(f"Weak Supervision Time (min): {((end - start) / 60.0):.1f}")

    start = timer()

    word_counter = collect_word_counter(train_cands)

    # Training config
    config = {
        "meta_config": {"verbose": True, "seed": 30},
        "model_config": {"model_path": None, "device": 0, "dataparallel": False},
        "learner_config": {
            "n_epochs": 500,
            "optimizer_config": {"lr": 0.001, "l2": 0.005},
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {"model/all/train/loss": "min"},
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }

    emmental.init(log_dir=Meta.log_path, config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens marking the start/end of each of the ``arity``
    # candidate arguments.
    specials = [
        token for i in range(arity) for token in (f"~~[[{i}", f"{i}]]~~")
    ]

    emb_layer = EmbeddingModule(
        word_counter=word_counter, word_dim=300, specials=specials
    )

    train_idxs = []
    train_dataloaders = []
    for idx, relation in enumerate(rel_list):
        # Keep only rows whose marginal is not uniform.
        diffs = marginals[idx].max(axis=1) - marginals[idx].min(axis=1)
        train_idxs.append(np.where(diffs > 1e-6)[0])

        # only uses dev set as training data, with human annotations
        train_dataloaders.append(
            EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(
                    relation,
                    dev_cands[idx],
                    F_dev[idx],
                    emb_layer.word2id,
                    marginals[idx],
                    train_idxs[idx],
                ),
                split="train",
                batch_size=256,
                shuffle=True,
            )
        )

    num_feature_keys = len(featurizer.get_keys())

    model = EmmentalModel(name="opamp_tasks")

    # List relation names, arities, list of classes
    tasks = create_task(
        rel_list,
        [2] * len(rel_list),
        num_feature_keys,
        [2] * len(rel_list),
        emb_layer,
        model="LogisticRegression",
    )

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()

    # If given a list of multi, will train on multiple
    emmental_learner.learn(model, train_dataloaders)

    def eval_loader(relation, cands, feats, split):
        """Build an unshuffled dataloader used only for prediction."""
        # NOTE(review): 2 is passed where training passes marginals; it is
        # presumably a placeholder label. Confirm against FonduerDataset.
        return EmmentalDataLoader(
            task_to_label_dict={relation: "labels"},
            dataset=FonduerDataset(relation, cands, feats, emb_layer.word2id, 2),
            split=split,
            batch_size=256,
            shuffle=False,
        )

    def positive_probs(preds, relation):
        """Return the TRUE-class probability column from ``model.predict`` output."""
        return np.array(preds["probs"][relation])[:, TRUE]

    # Score and dump each relation; rel_list only ever holds "gain"/"current".
    for idx, relation in enumerate(rel_list):
        is_gain = relation == "gain"

        test_preds = model.predict(
            eval_loader(relation, test_cands[idx], F_test[idx], "test"),
            return_preds=True,
        )
        best_result, best_b = scoring(
            test_preds,
            test_cands[idx],
            test_docs,
            is_gain=is_gain,
            num=100,
        )

        # Dump CSV files for analysis. Train is written without append=True;
        # test and dev are appended.
        train_preds = model.predict(
            eval_loader(relation, train_cands[idx], F_train[idx], "train"),
            return_preds=True,
        )
        Y_prob = positive_probs(train_preds, relation)
        output_csv(train_cands[idx], Y_prob, is_gain=is_gain)

        Y_prob = positive_probs(test_preds, relation)
        output_csv(test_cands[idx], Y_prob, is_gain=is_gain, append=True)
        dump_candidates(
            test_cands[idx], Y_prob, f"{relation}_test_probs.csv", is_gain=is_gain
        )

        dev_preds = model.predict(
            eval_loader(relation, dev_cands[idx], F_dev[idx], "dev"),
            return_preds=True,
        )
        Y_prob = positive_probs(dev_preds, relation)
        output_csv(dev_cands[idx], Y_prob, is_gain=is_gain, append=True)
        dump_candidates(
            dev_cands[idx], Y_prob, f"{relation}_dev_probs.csv", is_gain=is_gain
        )

    end = timer()
    logger.warning(
        f"Classification AND dump data Time (min): {((end - start) / 60.0):.1f}"
    )
def main(
    conn_string,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    gpu=None,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Train an Emmental thumbnail classifier and log its test-set scores.

    Parses the dataset, extracts figure mentions (figures whose image is more
    than 50px on its shorter side) and single-mention candidates, labels dev and
    test candidates from ``data/ground_truth.txt``, then trains a ResNet-18
    Emmental task on the dev candidates and logs precision/recall/F1 on test.

    Args:
        conn_string: Database connection string passed to ``Meta.init``.
        max_docs: Maximum number of documents, forwarded to ``parse_dataset``.
        parse: Not used by this function.
        first_time: Forwarded to ``parse_dataset``; also gates whether mention
            and candidate extraction are (re)run.
        gpu: Not used by this function.
        parallel: Parallelism for parsing and extraction.
        log_dir: Log directory relative to this file (defaults to ``"logs"``).
        verbose: Log at INFO level instead of WARNING.
    """
    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info("Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=first_time, parallel=parallel, max_docs=max_docs
    )
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")

    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    start = timer()

    Thumbnails = mention_subclass("Thumbnails")

    thumbnails_img = MentionFigures()

    class HasFigures(_Matcher):
        """Match figures whose image file exists locally and is > 50px on its shorter side."""

        def _f(self, m):
            file_path = ""
            # The image may live under any split's html directory; if it exists
            # under several, the last matching prefix wins.
            for prefix in [
                f"{dirname}/data/train/html/",
                f"{dirname}/data/dev/html/",
                f"{dirname}/data/test/html/",
            ]:
                if os.path.exists(prefix + m.figure.url):
                    file_path = prefix + m.figure.url
            if file_path == "":
                return False
            # Close the image file handle; only the header-derived size is needed.
            with Image.open(file_path) as img:
                width, height = img.size
            return min(width, height) > 50

    mention_extractor = MentionExtractor(
        session, [Thumbnails], [thumbnails_img], [HasFigures()], parallelism=parallel
    )

    if first_time:
        mention_extractor.apply(docs)

    logger.info("Total Mentions: {}".format(session.query(Mention).count()))

    ThumbnailLabel = candidate_subclass("ThumbnailLabel", [Thumbnails])

    candidate_extractor = CandidateExtractor(
        session, [ThumbnailLabel], throttlers=[None], parallelism=parallel
    )

    if first_time:
        candidate_extractor.apply(train_docs, split=0)
        candidate_extractor.apply(dev_docs, split=1)
        candidate_extractor.apply(test_docs, split=2)

    train_cands = candidate_extractor.get_candidates(split=0)
    # Sort the dev_cands, which are used for training, for deterministic behavior
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2)

    end = timer()
    logger.warning(f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info("Total train candidate:\t{}".format(len(train_cands[0])))
    logger.info("Total dev candidate:\t{}".format(len(dev_cands[0])))
    logger.info("Total test candidate:\t{}".format(len(test_cands[0])))

    # Each ground-truth line is whitespace-separated fields, lower-cased and
    # joined with "::" so it can be compared against LF_gt_label's key.
    with open(f"{dirname}/data/ground_truth.txt", "r") as fin:
        gt = {"::".join(line.lower().split()) for line in fin}

    # Labeling
    start = timer()

    def LF_gt_label(c):
        """Return TRUE if the candidate's "<doc>.pdf::<image file>" key is in gt."""
        doc_file_id = (
            f"{c[0].context.figure.document.name.lower()}.pdf::"
            f"{os.path.basename(c[0].context.figure.url.lower())}"
        )
        return TRUE if doc_file_id in gt else FALSE

    gt_dev = [LF_gt_label(cand) for cand in dev_cands[0]]
    gt_test = [LF_gt_label(cand) for cand in test_cands[0]]

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    batch_size = 64
    input_size = 224
    K = 2

    emmental.init(log_dir=Meta.log_path, config=emmental_config)

    emmental.Meta.config["learner_config"]["task_scheduler_config"][
        "task_scheduler"
    ] = DauphinScheduler(augment_k=K, enlarge=1)

    # The dev candidates (with ground-truth labels) serve as the training set.
    train_dataset = ThumbnailDataset(
        "Thumbnail",
        dev_cands[0],
        gt_dev,
        "train",
        prob_label=True,
        prefix=f"{dirname}/data/dev/html/",
        input_size=input_size,
        transform_cls=Augmentation(2),
        k=K,
    )

    val_dataset = ThumbnailDataset(
        "Thumbnail",
        dev_cands[0],
        gt_dev,
        "valid",
        prob_label=False,
        prefix=f"{dirname}/data/dev/html/",
        input_size=input_size,
        k=1,
    )

    test_dataset = ThumbnailDataset(
        "Thumbnail",
        test_cands[0],
        gt_test,
        "test",
        prob_label=False,
        prefix=f"{dirname}/data/test/html/",
        input_size=input_size,
        k=1,
    )

    dataloaders = []

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=train_dataset,
            split="train",
            shuffle=True,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=val_dataset,
            split="valid",
            shuffle=False,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=test_dataset,
            split="test",
            shuffle=False,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    model = EmmentalModel(name="Thumbnail")
    model.add_task(
        create_task("Thumbnail", n_class=2, model="resnet18", pretrained=True)
    )

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, dataloaders)

    scores = model.score(dataloaders)

    logger.warning("Model Score:")
    logger.warning(f"precision: {scores['Thumbnail/Thumbnail/test/precision']:.3f}")
    logger.warning(f"recall: {scores['Thumbnail/Thumbnail/test/recall']:.3f}")
    logger.warning(f"f1: {scores['Thumbnail/Thumbnail/test/f1']:.3f}")
def test_cand_gen_cascading_delete():
    """Test cascading the deletion of candidates.

    Parses one HTML/PDF document, extracts Part/Temp mentions and PartTemp
    candidates, then checks that:

    * deleting a row via the parent ``Candidate`` query also removes it from
      ``PartTemp``;
    * deleting a ``PartTemp`` candidate leaves its ``Part``/``Temp`` mentions;
    * ``mention_extractor.clear()`` leaves zero mentions and zero candidates.
    """
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )
    # Presumably clears mentions left by earlier tests so the exact counts
    # below hold — confirm against MentionExtractor.clear_all.
    mention_extractor.clear_all()
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Mention).count() == 93
    assert session.query(Part).count() == 70
    assert session.query(Temp).count() == 23
    part = session.query(Part).order_by(Part.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1432
    assert session.query(Candidate).count() == 1432
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].temps) == 23

    # Delete from parent class should cascade to child
    x = session.query(Candidate).first()
    session.query(Candidate).filter_by(id=x.id).delete(synchronize_session="fetch")
    assert session.query(Candidate).count() == 1431
    assert session.query(PartTemp).count() == 1431

    # Test that deletion of a Candidate does not delete the Mention
    x = session.query(PartTemp).first()
    session.query(PartTemp).filter_by(id=x.id).delete(synchronize_session="fetch")
    assert session.query(PartTemp).count() == 1430
    assert session.query(Temp).count() == 23
    assert session.query(Part).count() == 70

    # Clearing Mentions should also delete Candidates
    mention_extractor.clear()
    assert session.query(Mention).count() == 0
    assert session.query(Part).count() == 0
    assert session.query(Temp).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(Candidate).count() == 0
def test_cand_gen(caplog):
    """Test extracting candidates from mentions from documents.

    Also checks that ``MentionExtractor`` rejects argument lists of mismatched
    length, that deleting candidates keeps their mentions, and that deleting
    mentions removes the candidates built on them.
    """
    caplog.set_level(logging.INFO)
    # SpaCy on macOS has issues with parallel parsing. NOTE: os.name == "posix"
    # is also true on Linux, so every POSIX platform runs serially here.
    if os.name == "posix":
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 10
    session = Meta.init("postgres://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing: only (re)parse when the database does not already hold the corpus.
    num_docs = session.query(Document).count()
    if num_docs != max_docs:
        logger.info("Parsing...")
        doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
        corpus_parser = Parser(
            structural=True, lingual=True, visual=True, pdf_path=pdf_path
        )
        corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 5892
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    # Three mention classes but only two mention spaces.
    with pytest.raises(ValueError):
        mention_extractor = MentionExtractor(
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    # Three valid mention spaces but only two matchers, so the matcher arity is
    # the only thing that can trigger the error.
    with pytest.raises(ValueError):
        mention_extractor = MentionExtractor(
            [Part, Temp, Volt],
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor = MentionExtractor(
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 234
    assert session.query(Volt).count() == 108
    assert session.query(Temp).count() == 118
    part = session.query(Part).order_by(Part.id).all()[0]
    volt = session.query(Volt).order_by(Volt.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info("Part: {}".format(part.span))
    logger.info("Volt: {}".format(volt.span))
    logger.info("Temp: {}".format(temp.span))

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 3385
    assert session.query(PartVolt).count() == 3364
    assert session.query(Candidate).count() == 6749
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].volts) == 33
    assert len(docs[0].temps) == 18

    # Test that deletion of a Candidate does not delete the Mention
    session.query(PartTemp).delete()
    assert session.query(PartTemp).count() == 0
    assert session.query(Temp).count() == 118
    assert session.query(Part).count() == 234

    # Test deletion of Candidate if Mention is deleted
    assert session.query(PartVolt).count() == 3364
    assert session.query(Volt).count() == 108
    session.query(Volt).delete()
    assert session.query(Volt).count() == 0
    assert session.query(PartVolt).count() == 0
def test_mention_longest_match(caplog):
    """Test longest match filtering in mention extraction.

    ``Place`` mentions are restricted to table rows containing "birth_place".
    With ``longest_match_only=False`` both "Sinking Spring Farm" and its
    sub-span "Farm" are extracted (23 spans). With ``longest_match_only=True``
    the sub-span is dropped (4 spans).
    """
    caplog.set_level(logging.INFO)
    # SpaCy on mac has issue on parallel parsing
    PARALLEL = 1

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/pure_html/lincoln_short.html"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    name_ngrams = MentionNgramsPart(n_max=3)
    place_ngrams = MentionNgramsTemp(n_max=4)

    Name = mention_subclass("Name")
    Place = mention_subclass("Place")

    def is_birthplace_table_row(mention):
        """Return True if ``mention`` sits in a table row whose ngrams include "birth_place"."""
        if not mention.sentence.is_tabular():
            return False
        ngrams = get_row_ngrams(mention, lower=True)
        return "birth_place" in ngrams

    def _extract_place_spans(longest_match_only):
        """Run mention extraction with the given matcher mode; return all Place spans."""
        birthplace_matcher = LambdaFunctionMatcher(
            func=is_birthplace_table_row, longest_match_only=longest_match_only
        )
        mention_extractor = MentionExtractor(
            session,
            [Name, Place],
            [name_ngrams, place_ngrams],
            [PersonMatcher(), birthplace_matcher],
        )
        mention_extractor.apply(docs, parallelism=PARALLEL)
        mentions = session.query(Place).all()
        return [x.context.get_span() for x in mentions]

    mention_spans = _extract_place_spans(longest_match_only=False)
    assert "Sinking Spring Farm" in mention_spans
    assert "Farm" in mention_spans
    assert len(mention_spans) == 23

    mention_spans = _extract_place_spans(longest_match_only=True)
    assert "Sinking Spring Farm" in mention_spans
    assert "Farm" not in mention_spans
    assert len(mention_spans) == 4
def main(
    conn_string,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    gpu=None,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Train a MeTaL EndModel thumbnail classifier via random search and log scores.

    Parses the dataset, extracts figure mentions (figures whose image is more
    than 50px on its shorter side) and candidates, labels dev/test candidates
    from ``data/ground_truth.txt``, runs a ``RandomSearchTuner`` over l2/lr
    using the dev candidates as training data, and logs test precision/recall/F1.

    Side effect: changes the process working directory to this file's directory.

    Args:
        conn_string: Database connection string passed to ``Meta.init``.
        max_docs: Maximum number of documents, forwarded to ``parse_dataset``.
        parse: Not used by this function.
        first_time: Forwarded to ``parse_dataset``; also gates whether mention
            and candidate extraction are (re)run.
        gpu: If truthy, selects ``int(gpu)`` as the CUDA device and sets the
            model device to ``"cuda"``.
        parallel: Parallelism for parsing and extraction.
        log_dir: Log directory relative to this file (defaults to ``"logs"``).
        verbose: Log at INFO level instead of WARNING.
    """
    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    tuner_config = {"max_search": 3}

    em_config = {
        # GENERAL
        "seed": None,
        "verbose": True,
        "show_plots": True,
        # Network
        # The first value is the output dim of the input module (or the sum of
        # the output dims of all the input modules if multitask=True and
        # multiple input modules are provided). The last value is the
        # output dim of the head layer (i.e., the cardinality of the
        # classification task). The remaining values are the output dims of
        # middle layers (if any). The number of middle layers will be inferred
        # from this list.
        # "layer_out_dims": [10, 2],
        # Input layer configs
        "input_layer_config": {
            "input_relu": False,
            "input_batchnorm": False,
            "input_dropout": 0.0,
        },
        # Middle layer configs
        "middle_layer_config": {
            "middle_relu": False,
            "middle_batchnorm": False,
            "middle_dropout": 0.0,
        },
        # Can optionally skip the head layer completely, for e.g. running baseline
        # models...
        "skip_head": True,
        # GPU
        "device": "cpu",
        # MODEL CLASS: "resnet18" (the input module is built with get_cnn below)
        # DATA CONFIG
        "src": "gm",
        # TRAINING
        "train_config": {
            # Display
            "print_every": 1,  # Print after this many epochs
            "disable_prog_bar": False,  # Disable progress bar each epoch
            # Dataloader
            "data_loader_config": {"batch_size": 32, "num_workers": 8, "sampler": None},
            # Loss weights
            "loss_weights": [0.5, 0.5],
            # Train Loop
            "n_epochs": 20,
            # 'grad_clip': 0.0,
            "l2": 0.0,
            # "lr": 0.01,
            "validation_metric": "accuracy",
            "validation_freq": 1,  # Evaluate dev during training every this many epochs
            # Optimizer
            "optimizer_config": {
                "optimizer": "adam",
                "optimizer_common": {"lr": 0.01},
                # Optimizer - SGD
                "sgd_config": {"momentum": 0.9},
                # Optimizer - Adam
                "adam_config": {"betas": (0.9, 0.999)},
            },
            # Scheduler
            "scheduler_config": {
                # ['constant', 'exponential', 'reduce_on_plateau']
                "scheduler": "reduce_on_plateau",
                # Freeze learning rate initially this many epochs
                "lr_freeze": 0,
                # Scheduler - exponential
                "exponential_config": {"gamma": 0.9},  # decay rate
                # Scheduler - reduce_on_plateau
                "plateau_config": {
                    "factor": 0.5,
                    "patience": 1,
                    "threshold": 0.0001,
                    "min_lr": 1e-5,
                },
            },
            # Checkpointer
            "checkpoint": True,
            "checkpoint_config": {
                # The initial best score to beat to merit checkpointing
                "checkpoint_min": -1,
                # Don't start taking checkpoints until after this many epochs
                "checkpoint_runway": 0,
            },
        },
    }

    session = Meta.init(conn_string).Session()

    # All data paths below are relative to this file's directory.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    logger.info(f"CWD: {os.getcwd()}")
    dirname = "."

    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=first_time, parallel=parallel, max_docs=max_docs
    )

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")

    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    Thumbnails = mention_subclass("Thumbnails")

    thumbnails_img = MentionFigures()

    class HasFigures(_Matcher):
        """Match figures whose image file exists locally and is > 50px on its shorter side."""

        def _f(self, m):
            file_path = ""
            # If the image exists under several split directories, the last
            # matching prefix wins.
            for prefix in ["data/train/html/", "data/dev/html/", "data/test/html/"]:
                if os.path.exists(prefix + m.figure.url):
                    file_path = prefix + m.figure.url
            if file_path == "":
                return False
            # Close the image file handle; only the header-derived size is needed.
            with Image.open(file_path) as img:
                width, height = img.size
            return min(width, height) > 50

    mention_extractor = MentionExtractor(
        session, [Thumbnails], [thumbnails_img], [HasFigures()], parallelism=parallel
    )

    if first_time:
        mention_extractor.apply(docs)

    logger.info("Total Mentions: {}".format(session.query(Mention).count()))

    ThumbnailLabel = candidate_subclass("ThumbnailLabel", [Thumbnails])

    candidate_extractor = CandidateExtractor(
        session, [ThumbnailLabel], throttlers=[None], parallelism=parallel
    )

    if first_time:
        candidate_extractor.apply(train_docs, split=0)
        candidate_extractor.apply(dev_docs, split=1)
        candidate_extractor.apply(test_docs, split=2)

    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)

    logger.info("Total train candidate:\t{}".format(len(train_cands[0])))
    logger.info("Total dev candidate:\t{}".format(len(dev_cands[0])))
    logger.info("Total test candidate:\t{}".format(len(test_cands[0])))

    # Each ground-truth line is whitespace-separated fields, lower-cased and
    # joined with "::" so it can be compared against LF_gt_label's key.
    with open("data/ground_truth.txt", "r") as fin:
        gt = {"::".join(line.lower().split()) for line in fin}

    def LF_gt_label(c):
        """Return TRUE if the candidate's "<doc>.pdf::<image file>" key is in gt."""
        doc_file_id = (
            f"{c[0].context.figure.document.name.lower()}.pdf::"
            f"{os.path.basename(c[0].context.figure.url.lower())}"
        )
        return TRUE if doc_file_id in gt else FALSE

    # Dev labels in two forms: a probabilistic [P(true), P(false)] row for
    # training, and a hard label (1.0 = true, 2.0 = false) for validation.
    gt_dev_pb = []
    gt_dev = []
    for cand in dev_cands[0]:
        if LF_gt_label(cand) == TRUE:
            gt_dev_pb.append([1.0, 0.0])
            gt_dev.append(1.0)
        else:
            gt_dev_pb.append([0.0, 1.0])
            gt_dev.append(2.0)

    gt_test = [LF_gt_label(cand) for cand in test_cands[0]]

    batch_size = 64
    input_size = 224

    # The labelled dev candidates serve as the training set.
    train_loader = torch.utils.data.DataLoader(
        ImageList(
            data=dev_cands[0],
            label=torch.Tensor(gt_dev_pb),
            transform=transform(input_size),
            prefix="data/dev/html/",
        ),
        batch_size=batch_size,
        shuffle=False,
    )

    dev_loader = torch.utils.data.DataLoader(
        ImageList(
            data=dev_cands[0],
            label=gt_dev,
            transform=transform(input_size),
            prefix="data/dev/html/",
        ),
        batch_size=batch_size,
        shuffle=False,
    )

    test_loader = torch.utils.data.DataLoader(
        ImageList(
            data=test_cands[0],
            label=gt_test,
            transform=transform(input_size),
            prefix="data/test/html/",
        ),
        batch_size=100,
        shuffle=False,
    )

    search_space = {
        "l2": [0.001, 0.0001, 0.00001],  # linear range
        "lr": {"range": [0.0001, 0.1], "scale": "log"},  # log range
    }

    train_config = em_config["train_config"]

    # Defining network parameters
    num_classes = 2
    pretrained = True

    # Set CUDA device
    if gpu:
        em_config["device"] = "cuda"
        torch.cuda.set_device(int(gpu))

    # Initializing input module
    input_module = get_cnn("resnet18", pretrained=pretrained, num_classes=num_classes)

    # Initializing model object
    init_args = [[num_classes]]
    init_kwargs = {"input_module": input_module}
    init_kwargs.update(em_config)

    # Searching model
    log_config = {"log_dir": os.path.join(dirname, log_dir), "run_name": "image"}
    searcher = RandomSearchTuner(EndModel, **log_config)

    end_model = searcher.search(
        search_space,
        dev_loader,
        train_args=[train_loader],
        init_args=init_args,
        init_kwargs=init_kwargs,
        train_kwargs=train_config,
        max_search=tuner_config["max_search"],
    )

    # Evaluating model; scores are returned in the order of ``metric``.
    scores = end_model.score(
        test_loader,
        metric=["accuracy", "precision", "recall", "f1"],
        break_ties="abstain",
    )
    logger.warning("End Model Score:")
    logger.warning(f"precision: {scores[1]:.3f}")
    logger.warning(f"recall: {scores[2]:.3f}")
    logger.warning(f"f1: {scores[3]:.3f}")
def test_e2e(caplog):
    """Run an end-to-end test on documents of the hardware domain.

    Covers parsing, mention and candidate extraction, featurization
    (including FeatureKey dropping), gold-label loading, labeling, and training
    and evaluating several discriminative models on entity-level F1.
    """
    caplog.set_level(logging.INFO)

    PARALLEL = 4

    max_docs = 12
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info("Docs: {}".format(num_docs))
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info("Sentences: {}".format(num_sentences))

    # Divide into test and train
    docs = corpus_parser.get_documents()
    ld = len(docs)
    assert ld == len(corpus_parser.get_last_documents())

    # Expected per-document structure, in the order returned by get_documents().
    expected_sentences = [799, 663, 784, 661, 513, 700, 528, 161, 228, 511, 331, 528]
    expected_tables = [9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9]
    expected_figures = [32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27]
    n_checked = len(expected_sentences)
    checked_docs = docs[:n_checked]
    assert [len(doc.sentences) for doc in checked_docs] == expected_sentences
    # Check table numbers
    assert [len(doc.tables) for doc in checked_docs] == expected_tables
    # Check figure numbers
    assert [len(doc.figures) for doc in checked_docs] == expected_figures
    # Check caption numbers: no document in this corpus has captions.
    assert [len(doc.captions) for doc in checked_docs] == [0] * n_checked

    # Deterministic split by document name: first 50% train, next 25% dev,
    # remainder test.
    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (doc_name, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 147
    assert session.query(Volt).count() == 140
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299
    assert (
        len(
            mention_extractor.get_mentions(
                docs=[session.query(Document).filter(Document.name == "112823").first()]
            )[0]
        )
        == 70
    )

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    # Splits: 0 = train, 1 = dev, 2 = test.
    for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3684
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 72
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 448
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3684
    assert (
        len(
            candidate_extractor.get_candidates(
                docs=[session.query(Document).filter(Document.name == "112823").first()]
            )[0]
        )
        == 1496
    )

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 225
    assert session.query(FeatureKey).count() == 1179

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NFP NN NFP]"])
    assert session.query(FeatureKey).count() == 1178

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(
        session.query(FeatureKey)
        .filter(FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]")
        .one()
        .candidate_classes
    ) == {"part_temp", "part_volt"}
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]"
    ).one().candidate_classes == ["part_temp"]
    assert session.query(FeatureKey).count() == 1178

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1177
    session.query(Feature).delete()
    session.query(FeatureKey).delete()

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6669
    assert session.query(FeatureKey).count() == 4161
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3684, 4161)
    assert F_train[1].shape == (2985, 4161)
    assert len(featurizer.get_keys()) == 4161

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6894
    assert session.query(FeatureKey).count() == 4161
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (72, 4161)
    assert F_dev[1].shape == (153, 4161)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8486
    assert session.query(FeatureKey).count() == 4161
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (448, 4161)
    assert F_test[1].shape == (1144, 4161)

    gold_file = "tests/data/hardware_tutorial_gold.csv"
    load_hardware_labels(session, PartTemp, gold_file, ATTRIBUTE, annotator_name="gold")
    assert session.query(GoldLabel).count() == 4204
    load_hardware_labels(session, PartVolt, gold_file, ATTRIBUTE, annotator_name="gold")
    assert session.query(GoldLabel).count() == 8486

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]

    labeler = Labeler(session, [PartTemp, PartVolt])

    # A flat LF list for two candidate classes must be rejected.
    with pytest.raises(ValueError):
        labeler.apply(split=0, lfs=stg_temp_lfs, train=True, parallelism=PARALLEL)

    labeler.apply(
        split=0, lfs=[stg_temp_lfs, ce_v_max_lfs], train=True, parallelism=PARALLEL
    )
    assert session.query(Label).count() == 6669
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3684, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3684, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3684, 1)

    gen_model = LabelModel(k=2)
    gen_model.train_model(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])[:, 1]

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)  # trusted local test fixture

    def _evaluate(disc_model, n_epochs):
        """Train ``disc_model`` on the PartTemp train split; log and return test F1.

        Reads ``train_marginals`` at call time, so it uses whichever generative
        model output was assigned most recently.
        """
        disc_model.train(
            (train_cands[0], F_train[0]), train_marginals, n_epochs=n_epochs, lr=0.001
        )
        test_score = disc_model.predictions((test_cands[0], F_test[0]), b=0.6)
        true_pred = [
            test_cands[0][idx] for idx in np.nditer(np.where(test_score > 0))
        ]
        (TP, FP, FN) = entity_level_f1(
            true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc
        )
        tp_len = len(TP)
        fp_len = len(FP)
        fn_len = len(FN)
        prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
        rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
        f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")
        logger.info("prec: {}".format(prec))
        logger.info("rec: {}".format(rec))
        logger.info("f1: {}".format(f1))
        return f1

    # The first, weaker LF set should give a mediocre model.
    f1 = _evaluate(LogisticRegression(), n_epochs=20)
    assert 0.3 < f1 < 0.7

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0, lfs=[stg_temp_lfs_2, ce_v_max_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 6669
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3684, 16)

    gen_model = LabelModel(k=2)
    gen_model.train_model(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])[:, 1]

    # Testing Logistic Regression
    f1 = _evaluate(LogisticRegression(), n_epochs=20)
    assert f1 > 0.7

    # Testing LSTM
    f1 = _evaluate(LSTM(), n_epochs=5)
    assert f1 > 0.7

    # Testing Sparse Logistic Regression
    f1 = _evaluate(SparseLogisticRegression(), n_epochs=20)
    assert f1 > 0.7

    # Testing Sparse LSTM
    f1 = _evaluate(SparseLSTM(), n_epochs=5)
    assert f1 > 0.7
# Report the parsed sentence count, then extract mentions and candidates from
# every parsed document using the classes/spaces/matchers in ``fonduerconfig``.
# NOTE(review): relies on ``session``, ``Sentence``, ``Document`` and
# ``PARALLEL`` being defined earlier in this script (not visible here).
print(f"Sentences: {session.query(Sentence).count()}")
# All parsed documents, ordered by name, are treated as training documents.
train_docs = session.query(Document).order_by(Document.name).all()

# Mention
from fonduerconfig import mention_classes, mention_spaces, matchers, candidate_classes
from fonduer.candidates import MentionExtractor

mention_extractor = MentionExtractor(
    session, mention_classes, mention_spaces, matchers
)

from fonduer.candidates.models import Mention

mention_extractor.apply(train_docs, parallelism=PARALLEL)

# Disabled per-class mention counts, kept for reference:
#num_names = session.query(Presidentname).count()
#num_pobs = session.query(Placeofbirth).count()
print(
    f"Total Mentions: {session.query(Mention).count()}"
)

from fonduer.candidates import CandidateExtractor

# Candidates from all documents are stored under split 0.
candidate_extractor = CandidateExtractor(session, candidate_classes)
candidate_extractor.apply(train_docs, split=0, parallelism=PARALLEL)
# get_candidates returns one list per candidate class; [0] is the first class.
train_cands = candidate_extractor.get_candidates(split=0)
print(
    f"Number of Candidates: {len(train_cands[0])}"
)
def _dump_pickle(obj, path):
    """Serialize ``obj`` to ``path`` with pickle, always closing the file."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def _load_pickle(path):
    """Load and return the pickled object stored at ``path``.

    NOTE: pickle must only be used on trusted files; the files read by
    ``main`` are caches produced by an earlier run of this script.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def main(
    conn_string,
    stg_temp_min=False,
    stg_temp_max=False,
    polarity=False,
    ce_v_max=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Run the transistor extraction pipeline end to end.

    Stages: parsing, mention extraction, candidate extraction, featurization,
    labeling + generative model, and multi-task classification with Emmental.
    When ``first_time`` is set, feature matrices and marginals are computed and
    cached as pickle files next to this script; otherwise they are reloaded.

    Args:
        conn_string: Database connection string passed to ``Meta.init``.
        stg_temp_min, stg_temp_max, polarity, ce_v_max: Enable the
            corresponding relation.
        max_docs: Maximum number of documents to parse.
        parse: Forwarded to ``parse_dataset`` as its ``first_time`` flag.
        first_time: Run mention/candidate extraction, featurization and
            labeling from scratch and (re)write the pickle caches.
        re_label: When not ``first_time``, update training labels with the
            current labeling functions.
        parallel: Parallelism for parsing, mention and candidate extraction.
        log_dir: Log directory relative to this script (default ``"logs"``).
        verbose: Log at INFO level instead of WARNING.
    """
    if not log_dir:
        log_dir = "logs"

    level = logging.INFO if verbose else logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    rel_list = []
    if stg_temp_min:
        rel_list.append("stg_temp_min")

    if stg_temp_max:
        rel_list.append("stg_temp_max")

    if polarity:
        rel_list.append("polarity")

    if ce_v_max:
        rel_list.append("ce_v_max")

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info("Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=parse, parallel=parallel, max_docs=max_docs
    )
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")

    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Part mentions are always needed: every relation pairs a Part with a value.
    Part = mention_subclass("Part")
    part_matcher = get_matcher("part")
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)

    mentions.append(Part)
    ngrams.append(part_ngrams)
    matchers.append(part_matcher)

    # Only do those that are enabled
    if stg_temp_min:
        StgTempMin = mention_subclass("StgTempMin")
        stg_temp_min_matcher = get_matcher("stg_temp_min")
        stg_temp_min_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMin)
        ngrams.append(stg_temp_min_ngrams)
        matchers.append(stg_temp_min_matcher)

    if stg_temp_max:
        StgTempMax = mention_subclass("StgTempMax")
        stg_temp_max_matcher = get_matcher("stg_temp_max")
        stg_temp_max_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMax)
        ngrams.append(stg_temp_max_ngrams)
        matchers.append(stg_temp_max_matcher)

    if polarity:
        Polarity = mention_subclass("Polarity")
        polarity_matcher = get_matcher("polarity")
        polarity_ngrams = MentionNgrams(n_max=1)

        mentions.append(Polarity)
        ngrams.append(polarity_ngrams)
        matchers.append(polarity_matcher)

    if ce_v_max:
        CeVMax = mention_subclass("CeVMax")
        ce_v_max_matcher = get_matcher("ce_v_max")
        ce_v_max_ngrams = MentionNgramsVolt(n_max=1)

        mentions.append(CeVMax)
        ngrams.append(ce_v_max_ngrams)
        matchers.append(ce_v_max_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")
    logger.info(f"Total Part: {session.query(Part).count()}")
    if stg_temp_min:
        logger.info(f"Total StgTempMin: {session.query(StgTempMin).count()}")
    if stg_temp_max:
        logger.info(f"Total StgTempMax: {session.query(StgTempMax).count()}")
    if polarity:
        logger.info(f"Total Polarity: {session.query(Polarity).count()}")
    if ce_v_max:
        logger.info(f"Total CeVMax: {session.query(CeVMax).count()}")

    # Candidate Extraction
    cands = []
    throttlers = []
    if stg_temp_min:
        PartStgTempMin = candidate_subclass("PartStgTempMin", [Part, StgTempMin])
        stg_temp_min_throttler = stg_temp_filter

        cands.append(PartStgTempMin)
        throttlers.append(stg_temp_min_throttler)

    if stg_temp_max:
        PartStgTempMax = candidate_subclass("PartStgTempMax", [Part, StgTempMax])
        stg_temp_max_throttler = stg_temp_filter

        cands.append(PartStgTempMax)
        throttlers.append(stg_temp_max_throttler)

    if polarity:
        PartPolarity = candidate_subclass("PartPolarity", [Part, Polarity])
        polarity_throttler = polarity_filter

        cands.append(PartPolarity)
        throttlers.append(polarity_throttler)

    if ce_v_max:
        PartCeVMax = candidate_subclass("PartCeVMax", [Part, CeVMax])
        ce_v_max_throttler = ce_v_max_filter

        cands.append(PartCeVMax)
        throttlers.append(ce_v_max_throttler)

    candidate_extractor = CandidateExtractor(session, cands, throttlers=throttlers)

    if first_time:
        # Split ids: 0 = train, 1 = dev, 2 = test. Use a distinct loop name so
        # the full `docs` corpus is not shadowed.
        for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(split_docs, split=i, parallelism=parallel)
            num_cands = (
                session.query(Candidate).filter(Candidate.split == i).count()
            )
            logger.info(f"Candidates in split={i}: {num_cands}")

    # These must be sorted for deterministic behavior.
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)

    end = timer()
    logger.warning(f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"Total train candidate: {sum(len(_) for _ in train_cands)}")
    logger.info(f"Total dev candidate: {sum(len(_) for _ in dev_cands)}")
    logger.info(f"Total test candidate: {sum(len(_) for _ in test_cands)}")

    pickle_file = os.path.join(dirname, "data/parts_by_doc_new.pkl")
    parts_by_doc = _load_pickle(pickle_file)

    # Check total recall
    for i, name in enumerate(rel_list):
        logger.info(name)
        result = entity_level_scores(
            candidates_to_entities(dev_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=dev_docs,
        )
        logger.info(f"{name} Total Dev Recall: {result.rec:.3f}")
        result = entity_level_scores(
            candidates_to_entities(test_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=test_docs,
        )
        logger.info(f"{name} Total Test Recall: {result.rec:.3f}")

    # Featurization
    start = timer()
    cands = []
    if stg_temp_min:
        cands.append(PartStgTempMin)

    if stg_temp_max:
        cands.append(PartStgTempMax)

    if polarity:
        cands.append(PartPolarity)

    if ce_v_max:
        cands.append(PartCeVMax)

    # Using parallelism = 1 for deterministic behavior.
    featurizer = Featurizer(session, cands, parallelism=1)

    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True)
        featurizer.apply(split=1)
        featurizer.apply(split=2)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    f_train_path = os.path.join(dirname, "F_train_dict.pkl")
    f_dev_path = os.path.join(dirname, "F_dev_dict.pkl")
    f_test_path = os.path.join(dirname, "F_test_dict.pkl")
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)

        # Cache the matrices keyed by relation name so later runs can skip
        # featurization entirely.
        F_train_dict = {}
        F_dev_dict = {}
        F_test_dict = {}
        for idx, relation in enumerate(rel_list):
            F_train_dict[relation] = F_train[idx]
            F_dev_dict[relation] = F_dev[idx]
            F_test_dict[relation] = F_test[idx]

        _dump_pickle(F_train_dict, f_train_path)
        _dump_pickle(F_dev_dict, f_dev_path)
        _dump_pickle(F_test_dict, f_test_path)
    else:
        F_train_dict = _load_pickle(f_train_path)
        F_dev_dict = _load_pickle(f_dev_path)
        F_test_dict = _load_pickle(f_test_path)

        F_train = [F_train_dict[relation] for relation in rel_list]
        F_dev = [F_dev_dict[relation] for relation in rel_list]
        F_test = [F_test_dict[relation] for relation in rel_list]

    # Report timing for both the fresh and the cached path, like other stages.
    end = timer()
    logger.warning(f"Featurization Time (min): {((end - start) / 60.0):.1f}")
    logger.info("Done.")

    for i, cand in enumerate(cands):
        logger.info(f"{cand} Train shape: {F_train[i].shape}")
        logger.info(f"{cand} Test shape: {F_test[i].shape}")
        logger.info(f"{cand} Dev shape: {F_dev[i].shape}")

    logger.info("Labeling training data...")

    # Labeling
    start = timer()
    lfs = []
    if stg_temp_min:
        lfs.append(stg_temp_min_lfs)

    if stg_temp_max:
        lfs.append(stg_temp_max_lfs)

    if polarity:
        lfs.append(polarity_lfs)

    if ce_v_max:
        lfs.append(ce_v_max_lfs)

    # Using parallelism = 1 for deterministic behavior.
    labeler = Labeler(session, cands, parallelism=1)

    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True)
        logger.info("Done...")

        # Uncomment if debugging LFs
        # load_transistor_labels(session, cands, ["ce_v_max"])
        # labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        # labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    elif re_label:
        logger.info("Updating LFs...")
        labeler.update(split=0, lfs=lfs)
        logger.info("Done...")

        # Uncomment if debugging LFs
        # labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        # labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    logger.info("Getting label matrices...")

    L_train = labeler.get_label_matrices(train_cands)

    # Uncomment if debugging LFs
    # L_dev = labeler.get_label_matrices(dev_cands)
    # L_dev_gold = labeler.get_gold_labels(dev_cands, annotator="gold")
    #
    # L_test = labeler.get_label_matrices(test_cands)
    # L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")

    logger.info("Done.")

    marginals_path = os.path.join(dirname, "marginals_dict.pkl")
    if first_time:
        marginals_dict = {
            relation: generative_model(L_train[idx])
            for idx, relation in enumerate(rel_list)
        }
        _dump_pickle(marginals_dict, marginals_path)
    else:
        marginals_dict = _load_pickle(marginals_path)

    marginals = [marginals_dict[relation] for relation in rel_list]

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    start = timer()

    word_counter = collect_word_counter(train_cands)

    # Training config
    config = {
        "meta_config": {"verbose": True, "seed": 17},
        "model_config": {"model_path": None, "device": 0, "dataparallel": False},
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {"lr": 0.001, "l2": 0.0},
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {"model/all/train/loss": "min"},
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }

    emmental.init(log_dir=Meta.log_path, config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(
        word_counter=word_counter, word_dim=300, specials=specials
    )
    train_idxs = []
    train_dataloader = []
    for idx, relation in enumerate(rel_list):
        # Keep only candidates whose marginals are not (near-)uniform, i.e.
        # candidates that received at least some labeling signal.
        diffs = marginals[idx].max(axis=1) - marginals[idx].min(axis=1)
        train_idxs.append(np.where(diffs > 1e-6)[0])

        train_dataloader.append(
            EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(
                    relation,
                    train_cands[idx],
                    F_train[idx],
                    emb_layer.word2id,
                    marginals[idx],
                    train_idxs[idx],
                ),
                split="train",
                batch_size=100,
                shuffle=True,
            )
        )

    num_feature_keys = len(featurizer.get_keys())

    model = EmmentalModel(name="transistor_tasks")

    # List relation names, arities, list of classes
    tasks = create_task(
        rel_list,
        [2] * len(rel_list),
        num_feature_keys,
        [2] * len(rel_list),
        emb_layer,
        model="LogisticRegression",
    )

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()

    # If given a list of multi, will train on multiple
    emmental_learner.learn(model, train_dataloader)

    # List of dataloader for each relation
    for idx, relation in enumerate(rel_list):
        test_dataloader = EmmentalDataLoader(
            task_to_label_dict={relation: "labels"},
            dataset=FonduerDataset(
                relation, test_cands[idx], F_test[idx], emb_layer.word2id, 2
            ),
            split="test",
            batch_size=100,
            shuffle=False,
        )

        test_preds = model.predict(test_dataloader, return_preds=True)

        best_result, best_b = scoring(
            relation,
            test_preds,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

        # Dump CSV files for CE_V_MAX and POLARITY for digi-key analysis.
        # File names: "<relation>_test_probs.csv" / "<relation>_dev_probs.csv".
        if relation in ("ce_v_max", "polarity"):
            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(
                    relation, dev_cands[idx], F_dev[idx], emb_layer.word2id, 2
                ),
                split="dev",
                batch_size=100,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            dump_candidates(test_cands[idx], Y_prob, f"{relation}_test_probs.csv")
            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            dump_candidates(dev_cands[idx], Y_prob, f"{relation}_dev_probs.csv")

    end = timer()
    logger.warning(f"Classification Time (min): {((end - start) / 60.0):.1f}")
def test_feature_extraction():
    """Test extracting candidates from mentions from documents.

    Also checks that the default feature library is exactly the union of the
    textual, tabular, structural and visual libraries, and that a custom
    extractor adds one key per candidate.
    """
    PARALLEL = 1

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(
        HTMLDocPreprocessor(docs_path, max_docs=max_docs), parallelism=PARALLEL
    )
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)
    temp_ngrams = MentionNgrams(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert docs[0].name == "112823"
    assert session.query(Part).count() == 58
    assert session.query(Temp).count() == 16
    first_part = session.query(Part).order_by(Part.id).all()[0]
    first_temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {first_part.context}")
    logger.info(f"Temp: {first_temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    candidate_extractor = CandidateExtractor(session, [PartTemp])
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    n_cands = session.query(PartTemp).count()

    def count_feature_keys(**featurizer_kwargs):
        """Featurize split 0, return the FeatureKey count, then clear it."""
        featurizer = Featurizer(session, [PartTemp], **featurizer_kwargs)
        featurizer.apply(split=0, train=True, parallelism=PARALLEL)
        n_keys = session.query(FeatureKey).count()
        featurizer.clear(train=True)
        return n_keys

    # Example feature extractor: one unique feature per candidate.
    def feat_ext(candidates):
        batch = candidates if isinstance(candidates, list) else [candidates]
        for cand in batch:
            yield cand.id, f"cand_id_{cand.id}", 1

    # Default feature library.
    n_default_feats = count_feature_keys()

    # Default feature library plus one customized extractor.
    n_default_w_customized_features = count_feature_keys(
        feature_extractors=FeatureExtractor(customize_feature_funcs=[feat_ext])
    )

    # Each feature library on its own, in a fixed order.
    per_library_counts = {
        library: count_feature_keys(
            feature_extractors=FeatureExtractor(features=[library])
        )
        for library in ("textual", "tabular", "structural", "visual")
    }

    assert n_default_feats == sum(per_library_counts.values())
    assert n_default_w_customized_features == n_default_feats + n_cands
#sents = session.query(Sentence).all() from fonduer.candidates import CandidateExtractor, MentionExtractor, MentionNgrams from fonduer.candidates.models import mention_subclass, candidate_subclass from fonduer.candidates.matchers import RegexMatchSpan, Union, LambdaFunctionMatcher from dataset_utils import price_match # Defining ngrams for candidates extraction_name = 'price' ngrams = MentionNgrams(n_max=5) # Define matchers matchers = LambdaFunctionMatcher(func=price_match) # Getting candidates PriceMention = mention_subclass("PriceMention") mention_extractor = MentionExtractor( session, [PriceMention], [ngrams], [matchers] ) mention_extractor.clear_all() mention_extractor.apply(docs, parallelism=parallelism) candidate_class = candidate_subclass("Price", [PriceMention]) candidate_extractor = CandidateExtractor(session, [candidate_class]) # Applying candidate extractors candidate_extractor.apply(docs, split=0, parallelism=parallelism) print("==============================") print(f"Candidate extraction results for {postgres_db_name}:") print("Number of candidates:", session.query(candidate_class).filter(candidate_class.split == 0).count()) print("==============================")
def test_incremental(caplog):
    """Run an end-to-end test on incremental additions.

    Parses, extracts, featurizes and labels one document, then adds a second
    document with ``clear=False`` / ``update`` and checks that mentions,
    candidates, features and labels grow as expected.
    """
    import sys

    caplog.set_level(logging.INFO)
    # SpaCy on mac has issue on parallel parsing. `os.name == "posix"` is also
    # true on Linux, so check for macOS explicitly.
    if sys.platform == "darwin":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 1

    # SQLAlchemy >= 1.4 no longer accepts the "postgres://" scheme.
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/dtc114w.html"
    pdf_path = "tests/data/pdf/dtc114w.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info("Docs: {}".format(num_docs))
    assert num_docs == max_docs

    docs = corpus_parser.get_documents()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 9

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 78

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 78

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 78
    assert session.query(FeatureKey).count() == 496
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (78, 496)
    assert len(featurizer.get_keys()) == 496

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    labeler.apply(split=0, lfs=[stg_temp_lfs], train=True, parallelism=PARALLEL)
    assert session.query(Label).count() == 78

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (78, 5)
    assert len(labeler.get_keys()) == 5

    # Incrementally add a second document without clearing existing data.
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser.apply(doc_preprocessor, pdf_path=pdf_path, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    new_docs = corpus_parser.get_last_documents()

    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # Get mentions from just the new docs
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)

    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 33

    # Just run candidate extraction and assign to split 0
    candidate_extractor.apply(new_docs, split=0, parallelism=PARALLEL, clear=False)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1574

    # Update features
    featurizer.update(new_docs, parallelism=PARALLEL)
    assert session.query(Feature).count() == 1574
    assert session.query(FeatureKey).count() == 2425
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (1574, 2425)
    assert len(featurizer.get_keys()) == 2425

    # Update Labels
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 1574
    assert session.query(LabelKey).count() == 6
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (1574, 6)
def _entity_level_prf(true_pred, gold_file, test_docs, parts_by_doc):
    """Compute and log entity-level precision, recall and F1 for ATTRIBUTE.

    Each metric is ``nan`` when its denominator is zero.

    Returns:
        Tuple ``(prec, rec, f1)``.
    """
    (TP, FP, FN) = entity_level_f1(
        true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc
    )
    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")
    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")
    return prec, rec, f1


def _predicted_true(disc_model, cands, F):
    """Return the candidates in ``cands`` that ``disc_model`` predicts as TRUE.

    Uses ``np.flatnonzero``; unlike ``np.nditer`` it also handles the
    zero-prediction case (``np.nditer`` raises on zero-size arrays).
    """
    test_score = disc_model.predict((cands, F), b=0.6, pos_label=TRUE)
    return [cands[i] for i in np.flatnonzero(test_score == TRUE)]


def test_e2e():
    """Run an end-to-end test on documents of the hardware domain."""
    PARALLEL = 4

    max_docs = 12

    fonduer.init_logging(
        log_dir="log_folder",
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = fonduer.Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Per-document structure counts, in sorted document order.
    expected_sentences = [799, 663, 784, 661, 513, 700, 528, 161, 228, 511, 331, 528]
    expected_tables = [9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9]
    expected_figures = [32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27]
    assert [len(doc.sentences) for doc in docs] == expected_sentences
    assert [len(doc.tables) for doc in docs] == expected_tables
    assert [len(doc.figures) for doc in docs] == expected_figures
    assert [len(doc.captions) for doc in docs] == [0] * max_docs

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (_, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 138
    assert session.query(Volt).count() == 140
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299
    assert (
        len(
            mention_extractor.get_mentions(
                docs=[session.query(Document).filter(Document.name == "112823").first()]
            )[0]
        )
        == 70
    )

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    # Split ids: 0 = train, 1 = dev, 2 = test. A distinct loop name avoids
    # shadowing the full `docs` list.
    for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3493
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 416
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3493
    assert (
        len(
            candidate_extractor.get_candidates(
                docs=[session.query(Document).filter(Document.name == "112823").first()]
            )[0]
        )
        == 1432
    )

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    assert session.query(FeatureKey).count() == 1260

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NNP NN IN]"])
    assert session.query(FeatureKey).count() == 1259

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(
        session.query(FeatureKey)
        .filter(FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]")
        .one()
        .candidate_classes
    ) == {"part_temp", "part_volt"}
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]"
    ).one().candidate_classes == ["part_temp"]
    assert session.query(FeatureKey).count() == 1259

    # Inserting the removed key
    featurizer.upsert_keys(
        ["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp, PartVolt]
    )
    assert set(
        session.query(FeatureKey)
        .filter(FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]")
        .one()
        .candidate_classes
    ) == {"part_temp", "part_volt"}
    assert session.query(FeatureKey).count() == 1259
    # Removing the key again
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1258
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6478
    assert session.query(FeatureKey).count() == 4538
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3493, 4538)
    assert F_train[1].shape == (2985, 4538)
    assert len(featurizer.get_keys()) == 4538

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6692
    assert session.query(FeatureKey).count() == 4538
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (61, 4538)
    assert F_dev[1].shape == (153, 4538)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8252
    assert session.query(FeatureKey).count() == 4538
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (416, 4538)
    assert F_test[1].shape == (1144, 4538)

    gold_file = "tests/data/hardware_tutorial_gold.csv"

    labeler = Labeler(session, [PartTemp, PartVolt])

    labeler.apply(
        docs=last_docs,
        lfs=[[gold], [gold]],
        table=GoldLabel,
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(GoldLabel).count() == 8252

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]
    # LFs must be given as one list per candidate class.
    with pytest.raises(ValueError):
        labeler.apply(split=0, lfs=stg_temp_lfs, train=True, parallelism=PARALLEL)

    labeler.apply(
        docs=train_docs,
        lfs=[stg_temp_lfs, ce_v_max_lfs],
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    # Test Dropping LabelerKey
    labeler.drop_keys(["LF_storage_row"])
    assert len(labeler.get_keys()) == 8

    # Test Upserting LabelerKey
    labeler.upsert_keys(["LF_storage_row"])
    assert "LF_storage_row" in [label.name for label in labeler.get_keys()]

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3493, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3493, 1)

    gen_model = LabelModel()
    gen_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = gen_model.predict_proba(L_train[0])

    disc_model = LogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]),
        train_marginals,
        X_dev=(train_cands[0], F_train[0]),
        Y_dev=L_train_gold[0].reshape(-1),
        b=0.6,
        pos_label=TRUE,
        n_epochs=5,
        lr=0.001,
    )

    true_pred = _predicted_true(disc_model, test_cands[0], F_test[0])

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    # The first LF set is deliberately weak: expect a middling F1.
    _, _, f1 = _entity_level_prf(true_pred, gold_file, test_docs, parts_by_doc)
    assert 0.3 < f1 < 0.7

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0, lfs=[stg_temp_lfs_2, ce_v_max_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 16)

    gen_model = LabelModel()
    gen_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = gen_model.predict_proba(L_train[0])

    # With the improved LFs every discriminative model should exceed 0.7 F1.
    for model_cls in (LogisticRegression, LSTM, SparseLogisticRegression, SparseLSTM):
        disc_model = model_cls()
        disc_model.train(
            (train_cands[0], F_train[0]), train_marginals, n_epochs=5, lr=0.001
        )
        true_pred = _predicted_true(disc_model, test_cands[0], F_test[0])
        _, _, f1 = _entity_level_prf(true_pred, gold_file, test_docs, parts_by_doc)
        assert f1 > 0.7

    # Evaluate mention level scores (using the last model, SparseLSTM)
    L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")
    Y_test = L_test_gold[0].reshape(-1)

    scores = disc_model.score((test_cands[0], F_test[0]), Y_test, b=0.6, pos_label=TRUE)
    logger.info(scores)
    assert scores["f1"] > 0.6
def test_too_many_clients_error_should_not_happen():
    """Too many clients error should not happen.

    Smoke test: parse one HTML document (with its PDF for visual features),
    then run mention extraction and candidate extraction, all with
    ``PARALLEL = 32`` workers against the database at ``CONN_STRING``.
    The test passes as long as none of these steps raises. Going by the
    test name, the failure being guarded against is a "too many clients"
    error from the database when many workers run concurrently.

    Candidate extraction is run twice:
    - with no throttlers at all;
    - with ``throttlers=[temp_throttler, None]``, i.e. a ``None`` entry for
      the second candidate class.
    """
    PARALLEL = 32
    logger.info(f"Parallel: {PARALLEL}")

    def do_nothing_matcher(fig):
        # Accept every figure; the ``fig`` argument is intentionally ignored.
        return True

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    # NOTE(review): part_matcher, temp_matcher and volt_matcher are not
    # defined in this function; presumably module-level helpers — confirm.
    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt]
    )  # Pass, no throttler
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    # Clear split 0 so the next extractor starts from an empty split.
    candidate_extractor.clear_all(split=0)

    # Test with None in throttlers in candidate extractor
    # NOTE(review): temp_throttler is defined outside this function — confirm.
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, None]
    )
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)