def test_candidate_with_nullable_mentions():
    """Check that a candidate class declared with ``nullables=[True]`` can hold NULL.

    Mentions are extracted with ``MentionNgramsTemp(n_max=2)`` and
    ``temp_matcher``; candidates of the unary ``CandidateTemp`` class are then
    extracted without a throttler. The result must contain one candidate per
    mention plus one extra candidate whose mention is ``None``.
    """
    html_file = "tests/data/html/112823.html"
    pdf_dir = "tests/data/pdf/"
    doc = parse_doc(html_file, "112823", pdf_dir)

    # Mention extraction.
    MentionTemp = mention_subclass("MentionTemp")
    ngram_space = MentionNgramsTemp(n_max=2)
    doc = MentionExtractorUDF([MentionTemp], [ngram_space], [temp_matcher]).apply(doc)
    assert len(doc.mention_temps) == 23

    # Candidate extraction over a unary relation whose argument is nullable.
    CandidateTemp = candidate_subclass(
        "CandidateTemp", [MentionTemp], nullables=[True]
    )
    extractor = CandidateExtractorUDF([CandidateTemp], [None], False, False, True)
    doc = extractor.apply(doc, split=0)

    # Every mention yields a candidate, plus one extra for the NULL mention.
    expected_count = len(doc.mention_temps) + 1
    assert len(doc.candidate_temps) == expected_count

    # One of the extracted candidates must carry a NULL mention.
    first_args = [cand[0] for cand in doc.candidate_temps]
    assert None in first_args
def test_visualizer():
    """Unit test of visualizer using the md document.

    Parses the md HTML/PDF pair, extracts single-token ``Org`` mentions with
    ``OrganizationMatcher``, builds unary ``Organization`` candidates and
    renders the first candidate with ``Visualizer``.
    """
    # NOTE(review): a later ``def test_visualizer`` in this module would
    # shadow this one, and pytest would only collect the later definition.
    from fonduer.utils.visualizer import Visualizer  # noqa

    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/md.pdf"

    # Grab the md document
    doc = parse_doc(docs_path, "md", pdf_path)
    assert doc.name == "md"

    # Mention extraction
    organization_ngrams = MentionNgrams(n_max=1)
    Org = mention_subclass("Org")
    organization_matcher = OrganizationMatcher()
    mention_extractor_udf = MentionExtractorUDF(
        [Org], [organization_ngrams], [organization_matcher]
    )
    doc = mention_extractor_udf.apply(doc)

    # Candidate extraction
    Organization = candidate_subclass("Organization", [Org])
    candidate_extractor_udf = CandidateExtractorUDF(
        [Organization], None, False, False, True
    )
    doc = candidate_extractor_udf.apply(doc, split=0)

    cands = doc.organizations
    # Fail with a clear message rather than an IndexError below.
    assert cands, "expected at least one Organization candidate"

    # Test visualizer. Visualizer is given the PDF directory, not the single
    # file used for parsing above.
    pdf_path = "tests/data/pdf_simple"
    vis = Visualizer(pdf_path)
    vis.display_candidates([cands[0]])
def test_visualizer():
    """Check bounding boxes and rendering for a candidate from the md document.

    After extracting single-token ``Org`` mentions and unary
    ``Organization`` candidates, the first candidate's mention boxes from
    ``get_box`` must have ``top <= bottom`` and ``left <= right``. They must
    also equal each mention context's ``get_bbox()``. The candidate is then
    displayed with ``Visualizer``.
    """
    from fonduer.utils.visualizer import Visualizer, get_box  # noqa

    html_file = "tests/data/html_simple/md.html"
    pdf_dir = "tests/data/pdf_simple/"

    # Parse the md document.
    doc = parse_doc(html_file, "md", pdf_dir)
    assert doc.name == "md"

    # Mention extraction.
    org_ngrams = MentionNgrams(n_max=1)
    Org = mention_subclass("Org")
    org_matcher = OrganizationMatcher()
    doc = MentionExtractorUDF([Org], [org_ngrams], [org_matcher]).apply(doc)

    # Candidate extraction.
    Organization = candidate_subclass("Organization", [Org])
    doc = CandidateExtractorUDF([Organization], None, False, False, True).apply(
        doc, split=0
    )

    # Work with the first extracted candidate.
    first_cand = doc.organizations[0]

    vis_pdf_dir = "tests/data/pdf_simple"
    vis = Visualizer(vis_pdf_dir)

    # Bounding boxes must be well-ordered and agree with get_bbox().
    boxes = [get_box(m.context) for m in first_cand.get_mentions()]
    for bbox in boxes:
        assert bbox.top <= bbox.bottom
        assert bbox.left <= bbox.right
    reference_boxes = [m.context.get_bbox() for m in first_cand.get_mentions()]
    assert boxes == reference_boxes

    # Render the candidate.
    vis.display_candidates([first_cand])
def filter_candidate(self, document):
    """Run candidate extraction for ``self.Email_C`` on *document*.

    Builds a ``CandidateExtractorUDF`` for the single candidate class
    ``self.Email_C``, using the throttler(s) from ``self.get_throttler()``
    and with self, nested and symmetric relations all disabled. It applies
    the extractor with ``split=0`` and returns whatever ``apply`` returns.
    """
    extractor = CandidateExtractorUDF(
        [self.Email_C],
        throttlers=self.get_throttler(),
        self_relations=False,
        nested_relations=False,
        symmetric_relations=False,
    )
    return extractor.apply(document, split=0)
def _load_pyfunc(model_path: str) -> Any:
    """Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.

    Restores the saved mention and candidate classes from *model_path*, then
    unpickles ``model.pkl`` from that directory. It rebuilds the parser,
    mention extractor and candidate extractor on the stored
    ``fonduer_model``. Depending on the flavor configuration's
    ``MODEL_TYPE`` (default ``"emmental"``), it also restores either:

    * the featurizer, feature keys, ``word2id`` and Emmental model, or
    * the labeler, labeler keys, labeling functions and label models.

    :param model_path: Directory containing the saved model artifacts
        (including ``model.pkl``).
    :return: The reconstructed Fonduer model.
    """
    # Load mention_classes
    _load_mention_classes(model_path)
    # Load candidate_classes
    _load_candidate_classes(model_path)

    # Load a pickled model. Use a context manager so the file handle is
    # closed deterministically instead of leaking until garbage collection.
    # NOTE(review): unpickling can execute arbitrary code; only load models
    # from trusted locations.
    with open(os.path.join(model_path, "model.pkl"), "rb") as model_file:
        model = pickle.load(model_file)

    fonduer_model = model["fonduer_model"]
    # The "preprosessor" key spelling must match the code that wrote the
    # pickle; it is a stored key, so do not change it here alone.
    fonduer_model.preprocessor = model["preprosessor"]
    fonduer_model.parser = ParserUDF(**model["parser"])
    fonduer_model.mention_extractor = MentionExtractorUDF(
        **model["mention_extractor"]
    )
    fonduer_model.candidate_extractor = CandidateExtractorUDF(
        **model["candidate_extractor"]
    )

    # Configure logging for Fonduer
    init_logging(log_dir="logs")

    pyfunc_conf = _get_flavor_configuration(
        model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME
    )
    candidate_classes = fonduer_model.candidate_extractor.candidate_classes

    fonduer_model.model_type = pyfunc_conf.get(MODEL_TYPE, "emmental")
    if fonduer_model.model_type == "emmental":
        emmental.init()
        fonduer_model.featurizer = FeaturizerUDF(candidate_classes, FeatureExtractor())
        fonduer_model.key_names = model["feature_keys"]
        fonduer_model.word2id = model["word2id"]

        # Load the emmental_model: it is stored as raw bytes, so wrap it in
        # an in-memory buffer for torch.load.
        buffer = BytesIO()
        buffer.write(model["emmental_model"])
        buffer.seek(0)
        fonduer_model.emmental_model = torch.load(buffer)
    else:
        fonduer_model.labeler = LabelerUDF(candidate_classes)
        fonduer_model.key_names = model["labeler_keys"]
        fonduer_model.lfs = model["lfs"]

        # Rebuild each label model by restoring its saved instance state.
        fonduer_model.label_models = []
        for state_dict in model["label_models_state_dict"]:
            label_model = LabelModel()
            label_model.__dict__.update(state_dict)
            fonduer_model.label_models.append(label_model)
    return fonduer_model
def test_multinary_relation_feature_extraction():
    """Test featurization of ternary ``PartTempVolt`` candidates.

    Checks that:

    * the number of distinct keys from the default feature library equals
      the sum of the counts from the textual, tabular, structural and visual
      libraries run individually;
    * a customized feature extractor emitting ``cand_id_<id>`` adds exactly
      one distinct key per candidate;
    * an exception raised inside a customized feature extractor propagates
      out of ``FeaturizerUDF.apply``.
    """
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    # Parsing
    doc = parse_doc(docs_path, "112823", pdf_path)
    assert len(doc.sentences) == 799

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)
    temp_ngrams = MentionNgrams(n_max=1)
    volt_ngrams = MentionNgrams(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor_udf = MentionExtractorUDF(
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )
    doc = mention_extractor_udf.apply(doc)

    assert len(doc.parts) == 62
    assert len(doc.temps) == 16
    assert len(doc.volts) == 33
    part = doc.parts[0]
    temp = doc.temps[0]
    volt = doc.volts[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")
    logger.info(f"Volt: {volt.context}")

    # Candidate Extraction
    PartTempVolt = candidate_subclass("PartTempVolt", [Part, Temp, Volt])
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTempVolt], None, False, False, True
    )
    doc = candidate_extractor_udf.apply(doc, split=0)

    # Manually set id as it is not set automatically b/c a database is not used.
    for cand_id, cand in enumerate(doc.part_temp_volts):
        cand.id = cand_id
    n_cands = len(doc.part_temp_volts)

    def count_feature_keys(feature_extractors):
        """Featurize ``doc`` and return the number of distinct feature keys."""
        featurizer_udf = FeaturizerUDF(
            [PartTempVolt], feature_extractors=feature_extractors
        )
        features = itertools.chain.from_iterable(featurizer_udf.apply(doc))
        return len({key for feature in features for key in feature["keys"]})

    # Featurization based on default feature library
    n_default_feats = count_feature_keys(FeatureExtractor())

    # Example feature extractor: emits one distinct key per candidate id.
    def feat_ext(candidates):
        candidates = candidates if isinstance(candidates, list) else [candidates]
        for candidate in candidates:
            yield candidate.id, f"cand_id_{candidate.id}", 1

    # Featurization with the default library plus one extra feature extractor
    n_default_w_customized_features = count_feature_keys(
        FeatureExtractor(customize_feature_funcs=[feat_ext])
    )

    # Example spurious feature extractor
    def bad_feat_ext(candidates):
        raise RuntimeError()

    # Featurization with a spurious feature extractor: its exception must
    # propagate out of apply().
    featurizer_udf = FeaturizerUDF(
        [PartTempVolt],
        feature_extractors=FeatureExtractor(customize_feature_funcs=[bad_feat_ext]),
    )
    logger.info("Featurizing with a spurious feature extractor...")
    with pytest.raises(RuntimeError):
        featurizer_udf.apply(doc)

    # Featurization with each feature library on its own
    n_textual_features = count_feature_keys(FeatureExtractor(features=["textual"]))
    n_tabular_features = count_feature_keys(FeatureExtractor(features=["tabular"]))
    n_structural_features = count_feature_keys(
        FeatureExtractor(features=["structural"])
    )
    n_visual_features = count_feature_keys(FeatureExtractor(features=["visual"]))

    assert n_default_feats == (
        n_textual_features
        + n_tabular_features
        + n_structural_features
        + n_visual_features
    )
    # feat_ext contributes exactly one new key per candidate.
    assert n_default_w_customized_features == n_default_feats + n_cands
def test_unary_relation_feature_extraction():
    """Test featurization of unary ``PartRel`` candidates.

    Checks that the number of distinct keys from the default feature library
    equals the sum of the counts from the textual, tabular, structural and
    visual libraries run individually.
    """
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    # Parsing
    doc = parse_doc(docs_path, "112823", pdf_path)
    assert len(doc.sentences) == 799

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)
    Part = mention_subclass("Part")

    mention_extractor_udf = MentionExtractorUDF([Part], [part_ngrams], [part_matcher])
    doc = mention_extractor_udf.apply(doc)

    assert doc.name == "112823"
    assert len(doc.parts) == 62
    part = doc.parts[0]
    logger.info(f"Part: {part.context}")

    # Candidate Extraction
    PartRel = candidate_subclass("PartRel", [Part])
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartRel], None, False, False, True
    )
    doc = candidate_extractor_udf.apply(doc, split=0)

    def count_feature_keys(feature_extractors):
        """Featurize ``doc`` and return the number of distinct feature keys."""
        featurizer_udf = FeaturizerUDF([PartRel], feature_extractors=feature_extractors)
        features = itertools.chain.from_iterable(featurizer_udf.apply(doc))
        return len({key for feature in features for key in feature["keys"]})

    # Featurization based on default feature library
    n_default_feats = count_feature_keys(FeatureExtractor())

    # Featurization with each feature library on its own
    n_textual_features = count_feature_keys(FeatureExtractor(features=["textual"]))
    n_tabular_features = count_feature_keys(FeatureExtractor(features=["tabular"]))
    n_structural_features = count_feature_keys(
        FeatureExtractor(features=["structural"])
    )
    n_visual_features = count_feature_keys(FeatureExtractor(features=["visual"]))

    assert n_default_feats == (
        n_textual_features
        + n_tabular_features
        + n_structural_features
        + n_visual_features
    )
def test_multimodal_cand():
    """Check unary candidates built over every structural context type.

    The pure-HTML radiology document (parsed without a PDF) is processed with
    one mention space per context type: document, section, table, figure,
    cell, paragraph, caption and sentence. Each uses ``DoNothingMatcher``.
    The per-type candidate counts must equal the mention counts.
    """
    file_name = "radiology"
    html_file = f"tests/data/pure_html/{file_name}.html"
    doc = parse_doc(html_file, file_name)
    assert len(doc.sentences) == 35

    # Mention subclasses and mention spaces (creation order kept stable).
    ms_doc = mention_subclass("m_doc")
    ms_sec = mention_subclass("m_sec")
    ms_tab = mention_subclass("m_tab")
    ms_fig = mention_subclass("m_fig")
    ms_cell = mention_subclass("m_cell")
    ms_para = mention_subclass("m_para")
    ms_cap = mention_subclass("m_cap")
    ms_sent = mention_subclass("m_sent")

    space_doc = MentionDocuments()
    space_sec = MentionSections()
    space_tab = MentionTables()
    space_fig = MentionFigures()
    space_cell = MentionCells()
    space_para = MentionParagraphs()
    space_cap = MentionCaptions()
    space_sent = MentionSentences()

    mention_classes = [ms_doc, ms_cap, ms_sec, ms_tab, ms_fig, ms_para, ms_sent, ms_cell]
    mention_spaces = [
        space_doc,
        space_cap,
        space_sec,
        space_tab,
        space_fig,
        space_para,
        space_sent,
        space_cell,
    ]
    # The same matcher instance is shared by every mention class.
    shared_matchers = [DoNothingMatcher()] * len(mention_classes)

    doc = MentionExtractorUDF(mention_classes, mention_spaces, shared_matchers).apply(
        doc
    )

    expected_mentions = (
        ("m_docs", 1),
        ("m_caps", 2),
        ("m_secs", 5),
        ("m_tabs", 2),
        ("m_figs", 2),
        ("m_paras", 30),
        ("m_sents", 35),
        ("m_cells", 21),
    )
    for attr_name, expected in expected_mentions:
        assert len(getattr(doc, attr_name)) == expected

    # Candidate subclasses, one unary relation per mention type.
    cs_doc = candidate_subclass("cs_doc", [ms_doc])
    cs_sec = candidate_subclass("cs_sec", [ms_sec])
    cs_tab = candidate_subclass("cs_tab", [ms_tab])
    cs_fig = candidate_subclass("cs_fig", [ms_fig])
    cs_cell = candidate_subclass("cs_cell", [ms_cell])
    cs_para = candidate_subclass("cs_para", [ms_para])
    cs_cap = candidate_subclass("cs_cap", [ms_cap])
    cs_sent = candidate_subclass("cs_sent", [ms_sent])

    candidate_classes = [cs_doc, cs_sec, cs_tab, cs_fig, cs_cell, cs_para, cs_cap, cs_sent]
    no_throttlers = [None] * len(candidate_classes)
    extractor = CandidateExtractorUDF(
        candidate_classes,
        no_throttlers,
        False,
        False,
        True,
    )
    doc = extractor.apply(doc, split=0)

    expected_candidates = (
        ("cs_docs", 1),
        ("cs_caps", 2),
        ("cs_secs", 5),
        ("cs_tabs", 2),
        ("cs_figs", 2),
        ("cs_paras", 30),
        ("cs_sents", 35),
        ("cs_cells", 21),
    )
    for attr_name, expected in expected_candidates:
        assert len(getattr(doc, attr_name)) == expected
def test_cand_gen():
    """Test extracting candidates from mentions from documents.

    Also checks the following:

    * ``MentionExtractor`` and ``CandidateExtractor`` raise ``ValueError``
      when their argument lists have mismatched arity.
    * Throttlers change the number of extracted candidates.
    * A ``None`` throttler entry leaves that candidate class with the same
      count as an unthrottled run, as the assertions below show.
    """

    def do_nothing_matcher(fig):
        return True

    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"
    doc = parse_doc(docs_path, "112823", pdf_path)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    with pytest.raises(ValueError):
        MentionExtractor(
            "dummy",
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    with pytest.raises(ValueError):
        MentionExtractor(
            "dummy",
            [Part, Temp, Volt],
            # Valid mention spaces, so the only defect under test is the
            # matchers' arity.
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor_udf = MentionExtractorUDF(
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    doc = mention_extractor_udf.apply(doc)

    assert len(doc.parts) == 70
    assert len(doc.volts) == 33
    assert len(doc.temps) == 23
    assert len(doc.figs) == 31
    part = doc.parts[0]
    volt = doc.volts[0]
    temp = doc.temps[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Volt: {volt.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    with pytest.raises(ValueError):
        CandidateExtractor(
            "dummy",
            [PartTemp, PartVolt],
            throttlers=[
                temp_throttler,
                volt_throttler,
                volt_throttler,
            ],  # Fail, mismatched arity
        )

    with pytest.raises(ValueError):
        CandidateExtractor(
            "dummy",
            [PartTemp],  # Fail, mismatched arity
            throttlers=[temp_throttler, volt_throttler],
        )

    # Test that no throttler in candidate extractor
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [None, None], False, False, True  # Pass, no throttler
    )
    doc = candidate_extractor_udf.apply(doc, split=0)

    assert len(doc.part_temps) == 1610
    assert len(doc.part_volts) == 2310

    # Reset candidate lists before re-extracting
    doc.part_temps = []
    doc.part_volts = []

    # Test with None in throttlers in candidate extractor
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [temp_throttler, None], False, False, True
    )
    doc = candidate_extractor_udf.apply(doc, split=0)
    assert len(doc.part_temps) == 1432
    assert len(doc.part_volts) == 2310

    # Reset candidate lists before re-extracting
    doc.part_temps = []
    doc.part_volts = []

    # Both candidate classes throttled
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [temp_throttler, volt_throttler], False, False, True
    )
    doc = candidate_extractor_udf.apply(doc, split=0)

    assert len(doc.part_temps) == 1432
    assert len(doc.part_volts) == 1993
    assert len(doc.parts) == 70
    assert len(doc.volts) == 33
    assert len(doc.temps) == 23