def test_candidate_with_nullable_mentions():
    """Test if mentions can be NULL."""
    # Inputs handed to parse_doc (a helper defined outside this view).
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/"
    doc = parse_doc(docs_path, "112823", pdf_path)

    # Mention Extraction
    MentionTemp = mention_subclass("MentionTemp")
    temp_ngrams = MentionNgramsTemp(n_max=2)
    # NOTE(review): the call below is cut off in this view -- its arguments and
    # closing parenthesis are missing, so the block does not parse as shown.
    # TODO: restore the full call from the original file.
    mention_extractor_udf = MentionExtractorUDF(
    doc = mention_extractor_udf.apply(doc)

    assert len(doc.mention_temps) == 23

    # Candidate Extraction
    # NOTE(review): the candidate_subclass(...) call below is also truncated
    # (trailing comma, no closing parenthesis). Given the docstring, the
    # missing argument presumably marks the mention as nullable -- TODO confirm.
    CandidateTemp = candidate_subclass("CandidateTemp", [MentionTemp],
    # Positional arguments: presumably [None] is the throttler list and the
    # three booleans are relation flags -- verify against the
    # CandidateExtractorUDF signature.
    candidate_extractor_udf = CandidateExtractorUDF([CandidateTemp], [None],
                                                    False, False, True)

    doc = candidate_extractor_udf.apply(doc, split=0)
    # The number of extracted candidates should be that of mentions + 1 (NULL)
    assert len(doc.candidate_temps) == len(doc.mention_temps) + 1
    # Extracted candidates should include one with NULL mention.
    assert None in [c[0] for c in doc.candidate_temps]
def test_visualizer():
    """Unit test of visualizer using the md document."""
    # Imported inside the test rather than at module level; the "noqa" marker
    # silences the linter for this line.
    from fonduer.utils.visualizer import Visualizer, get_box  # noqa

    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/"

    # Grab the md document
    doc = parse_doc(docs_path, "md", pdf_path)
    assert doc.name == "md"

    organization_ngrams = MentionNgrams(n_max=1)

    Org = mention_subclass("Org")

    # NOTE(review): organization_matcher is not referenced again in the lines
    # visible here; it was presumably the third argument of the truncated
    # MentionExtractorUDF call below -- TODO confirm.
    organization_matcher = OrganizationMatcher()

    # NOTE(review): call truncated in this view (trailing comma, no closing
    # parenthesis) -- TODO restore the remaining arguments.
    mention_extractor_udf = MentionExtractorUDF([Org], [organization_ngrams],

    doc = mention_extractor_udf.apply(doc)

    Organization = candidate_subclass("Organization", [Org])

    candidate_extractor_udf = CandidateExtractorUDF([Organization], None,
                                                    False, False, True)

    doc = candidate_extractor_udf.apply(doc, split=0)

    # Take one candidate
    cand = doc.organizations[0]

    # Same directory as above, this time without the trailing slash.
    pdf_path = "tests/data/pdf_simple"
    vis = Visualizer(pdf_path)

    # Test bounding boxes
    boxes = [get_box(mention.context) for mention in cand.get_mentions()]
    # Every box must be well-formed: top above (or equal to) bottom and left
    # before (or equal to) right.
    for box in boxes:
        assert box.top <= box.bottom
        assert box.left <= box.right
    # get_box() and the context's own get_bbox() are expected to agree.
    # NOTE(review): the list literal below is never closed in this view.
    assert boxes == [
        mention.context.get_bbox() for mention in cand.get_mentions()

    # Test visualizer
    # NOTE(review): no visualizer call is visible after this comment. The
    # fragment below takes `self` and reads `self.Email_C`, so it looks like a
    # method from a different class spliced in here; its CandidateExtractorUDF
    # call is also incomplete. TODO: restore from the original file.
    def filter_candidate(self, document):
        document = CandidateExtractorUDF([self.Email_C],
                                             document, split=0)

        return document
def _load_pyfunc(model_path: str) -> Any:
    """Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``."""

    # Load mention_classes
    # Load candidate_classes
    # NOTE(review): the two steps named above have no code in this view.
    # Load a pickled model
    # NOTE(review): pickle.load can execute arbitrary code while unpickling, so
    # model.pkl must come from a trusted source. The handle returned by open()
    # is also never closed explicitly.
    model = pickle.load(open(os.path.join(model_path, "model.pkl"), "rb"))
    fonduer_model = model["fonduer_model"]
    # "preprosessor" is the runtime dict key read here; presumably it matches
    # the key used by the code that wrote model.pkl, so do not "correct" the
    # spelling on this side alone.
    fonduer_model.preprocessor = model["preprosessor"]
    fonduer_model.parser = ParserUDF(**model["parser"])
    # NOTE(review): the next two calls are truncated in this view (no arguments,
    # no closing parenthesis) -- TODO restore from the original file.
    fonduer_model.mention_extractor = MentionExtractorUDF(
    fonduer_model.candidate_extractor = CandidateExtractorUDF(

    # Configure logging for Fonduer
    # NOTE(review): no logging setup is visible below this comment.

    # NOTE(review): call truncated (trailing comma, no closing parenthesis).
    pyfunc_conf = _get_flavor_configuration(model_path=model_path,
    candidate_classes = fonduer_model.candidate_extractor.candidate_classes

    # The model type falls back to "emmental" when the flavor configuration has
    # no MODEL_TYPE entry.
    fonduer_model.model_type = pyfunc_conf.get(MODEL_TYPE, "emmental")
    if fonduer_model.model_type == "emmental":
        # NOTE(review): call truncated (trailing comma, no closing parenthesis).
        fonduer_model.featurizer = FeaturizerUDF(candidate_classes,
        fonduer_model.key_names = model["feature_keys"]
        fonduer_model.word2id = model["word2id"]

        # Load the emmental_model
        # NOTE(review): as shown, `buffer` is a fresh, empty BytesIO when it is
        # handed to torch.load; the lines that fill it appear to be missing.
        buffer = BytesIO()
        fonduer_model.emmental_model = torch.load(buffer)
        # NOTE(review): as shown, the labeler setup below sits inside the
        # "emmental" branch and reassigns key_names set above; an `else:` line
        # appears to be missing here -- TODO confirm.
        fonduer_model.labeler = LabelerUDF(candidate_classes)
        fonduer_model.key_names = model["labeler_keys"]

        fonduer_model.lfs = model["lfs"]

        fonduer_model.label_models = []
        # NOTE(review): each iteration creates a LabelModel but, in this view,
        # neither uses state_dict nor appends to label_models.
        for state_dict in model["label_models_state_dict"]:
            label_model = LabelModel()
    return fonduer_model
def test_multinary_relation_feature_extraction():
    """Test extracting candidates from mentions from documents."""
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    # Parsing
    doc = parse_doc(docs_path, "112823", pdf_path)
    assert len(doc.sentences) == 799

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)
    temp_ngrams = MentionNgrams(n_max=1)
    volt_ngrams = MentionNgrams(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    # part_matcher, temp_matcher and volt_matcher are defined outside this view.
    # NOTE(review): the closing parenthesis of this call is missing in this
    # view -- TODO restore.
    mention_extractor_udf = MentionExtractorUDF(
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    doc = mention_extractor_udf.apply(doc)

    assert len(doc.parts) == 62
    assert len(doc.temps) == 16
    assert len(doc.volts) == 33
    part = doc.parts[0]
    temp = doc.temps[0]
    volt = doc.volts[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")
    logger.info(f"Volt: {volt.context}")

    # Candidate Extraction
    PartTempVolt = candidate_subclass("PartTempVolt", [Part, Temp, Volt])

    candidate_extractor_udf = CandidateExtractorUDF([PartTempVolt], None,
                                                    False, False, True)

    doc = candidate_extractor_udf.apply(doc, split=0)

    # Manually set id as it is not set automatically b/c a database is not used.
    i = 0
    for cand in doc.part_temp_volts:
        cand.id = i
        i = i + 1

    # Kept for the final assertion: the custom extractor below emits one
    # distinct key per candidate.
    n_cands = len(doc.part_temp_volts)

    # Featurization based on default feature library
    featurizer_udf = FeaturizerUDF([PartTempVolt], FeatureExtractor())

    # Test that featurization default feature library
    # The same three steps recur for every configuration below: apply the
    # featurizer, flatten its output, and count the distinct feature keys.
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_default_feats = len(key_set)

    # Example feature extractor
    # Accepts a single candidate or a list and yields one
    # (candidate id, feature name, value) triple per candidate.
    def feat_ext(candidates):
        candidates = candidates if isinstance(candidates,
                                              list) else [candidates]
        for candidate in candidates:
            yield candidate.id, f"cand_id_{candidate.id}", 1

    # Featurization with one extra feature extractor
    feature_extractors = FeatureExtractor(customize_feature_funcs=[feat_ext])
    # NOTE(review): call truncated in this view (trailing comma, no closing
    # parenthesis); presumably feature_extractors was the second argument.
    featurizer_udf = FeaturizerUDF([PartTempVolt],

    # Test that featurization default feature library with one extra feature extractor
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_default_w_customized_features = len(key_set)

    # Example spurious feature extractor
    def bad_feat_ext(candidates):
        raise RuntimeError()

    # Featurization with a spurious feature extractor
    # NOTE(review): both calls below are truncated in this view -- TODO restore.
    feature_extractors = FeatureExtractor(
    featurizer_udf = FeaturizerUDF([PartTempVolt],

    # Test that featurization default feature library with one extra feature extractor
    logger.info("Featurizing with a spurious feature extractor...")
    # The RuntimeError is expected to surface from apply().
    with pytest.raises(RuntimeError):
        features = featurizer_udf.apply(doc)

    # Featurization with only textual feature
    feature_extractors = FeatureExtractor(features=["textual"])
    # NOTE(review): call truncated in this view -- TODO restore.
    featurizer_udf = FeaturizerUDF([PartTempVolt],

    # Test that featurization textual feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_textual_features = len(key_set)

    # Featurization with only tabular feature
    feature_extractors = FeatureExtractor(features=["tabular"])
    # NOTE(review): call truncated in this view -- TODO restore.
    featurizer_udf = FeaturizerUDF([PartTempVolt],

    # Test that featurization tabular feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_tabular_features = len(key_set)

    # Featurization with only structural feature
    feature_extractors = FeatureExtractor(features=["structural"])
    # NOTE(review): call truncated in this view -- TODO restore.
    featurizer_udf = FeaturizerUDF([PartTempVolt],

    # Test that featurization structural feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_structural_features = len(key_set)

    # Featurization with only visual feature
    feature_extractors = FeatureExtractor(features=["visual"])
    # NOTE(review): call truncated in this view -- TODO restore.
    featurizer_udf = FeaturizerUDF([PartTempVolt],

    # Test that featurization visual feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_visual_features = len(key_set)

    # The default library must yield exactly the sum of the four per-modality
    # key counts.
    assert (n_default_feats == n_textual_features + n_tabular_features +
            n_structural_features + n_visual_features)

    # The custom extractor must add exactly one key per candidate.
    assert n_default_w_customized_features == n_default_feats + n_cands
def test_unary_relation_feature_extraction():
    """Test extracting unary candidates from mentions from documents."""
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    # Parsing
    doc = parse_doc(docs_path, "112823", pdf_path)
    assert len(doc.sentences) == 799

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)

    Part = mention_subclass("Part")

    # NOTE(review): call truncated in this view (trailing comma, no closing
    # parenthesis) -- TODO restore the remaining argument(s).
    mention_extractor_udf = MentionExtractorUDF([Part], [part_ngrams],
    doc = mention_extractor_udf.apply(doc)

    assert doc.name == "112823"
    assert len(doc.parts) == 62
    part = doc.parts[0]
    logger.info(f"Part: {part.context}")

    # Candidate Extraction
    # A unary relation: the candidate class is built from a single mention class.
    PartRel = candidate_subclass("PartRel", [Part])

    candidate_extractor_udf = CandidateExtractorUDF([PartRel], None, False,
                                                    False, True)
    doc = candidate_extractor_udf.apply(doc, split=0)

    # Featurization based on default feature library
    featurizer_udf = FeaturizerUDF([PartRel], FeatureExtractor())

    # Test that featurization default feature library
    # The same three steps recur for every configuration below: apply the
    # featurizer, flatten its output, and count the distinct feature keys.
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_default_feats = len(key_set)

    # Featurization with only textual feature
    feature_extractors = FeatureExtractor(features=["textual"])
    # NOTE(review): call truncated in this view; presumably feature_extractors
    # was the second argument -- TODO restore.
    featurizer_udf = FeaturizerUDF([PartRel],

    # Test that featurization textual feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_textual_features = len(key_set)

    # Featurization with only tabular feature
    feature_extractors = FeatureExtractor(features=["tabular"])
    # NOTE(review): call truncated in this view -- TODO restore.
    featurizer_udf = FeaturizerUDF([PartRel],

    # Test that featurization tabular feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_tabular_features = len(key_set)

    # Featurization with only structural feature
    feature_extractors = FeatureExtractor(features=["structural"])
    # NOTE(review): call truncated in this view -- TODO restore.
    featurizer_udf = FeaturizerUDF([PartRel],

    # Test that featurization structural feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_structural_features = len(key_set)

    # Featurization with only visual feature
    feature_extractors = FeatureExtractor(features=["visual"])
    # NOTE(review): call truncated in this view -- TODO restore.
    featurizer_udf = FeaturizerUDF([PartRel],

    # Test that featurization visual feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_visual_features = len(key_set)

    # The default library must yield exactly the sum of the four per-modality
    # key counts.
    assert (n_default_feats == n_textual_features + n_tabular_features +
            n_structural_features + n_visual_features)
def test_multimodal_cand():
    """Test multimodal candidate generation"""
    file_name = "radiology"
    docs_path = f"tests/data/pure_html/{file_name}.html"
    # No PDF path is passed to parse_doc here, unlike the other tests.
    doc = parse_doc(docs_path, file_name)

    assert len(doc.sentences) == 35

    # Mention Extraction

    # One mention class per context type.
    ms_doc = mention_subclass("m_doc")
    ms_sec = mention_subclass("m_sec")
    ms_tab = mention_subclass("m_tab")
    ms_fig = mention_subclass("m_fig")
    ms_cell = mention_subclass("m_cell")
    ms_para = mention_subclass("m_para")
    ms_cap = mention_subclass("m_cap")
    ms_sent = mention_subclass("m_sent")

    # One mention space per context type, paired with the classes above.
    m_doc = MentionDocuments()
    m_sec = MentionSections()
    m_tab = MentionTables()
    m_fig = MentionFigures()
    m_cell = MentionCells()
    m_para = MentionParagraphs()
    m_cap = MentionCaptions()
    m_sent = MentionSentences()

    # The two lists are index-aligned: ms[i] is the class for space m[i].
    ms = [ms_doc, ms_cap, ms_sec, ms_tab, ms_fig, ms_para, ms_sent, ms_cell]
    m = [m_doc, m_cap, m_sec, m_tab, m_fig, m_para, m_sent, m_cell]
    # List multiplication repeats one DoNothingMatcher instance eight times.
    matchers = [DoNothingMatcher()] * 8

    mention_extractor_udf = MentionExtractorUDF(ms, m, matchers)

    doc = mention_extractor_udf.apply(doc)

    assert len(doc.m_docs) == 1
    assert len(doc.m_caps) == 2
    assert len(doc.m_secs) == 5
    assert len(doc.m_tabs) == 2
    assert len(doc.m_figs) == 2
    assert len(doc.m_paras) == 30
    assert len(doc.m_sents) == 35
    assert len(doc.m_cells) == 21

    # Candidate Extraction
    # One unary candidate class per mention class.
    cs_doc = candidate_subclass("cs_doc", [ms_doc])
    cs_sec = candidate_subclass("cs_sec", [ms_sec])
    cs_tab = candidate_subclass("cs_tab", [ms_tab])
    cs_fig = candidate_subclass("cs_fig", [ms_fig])
    cs_cell = candidate_subclass("cs_cell", [ms_cell])
    cs_para = candidate_subclass("cs_para", [ms_para])
    cs_cap = candidate_subclass("cs_cap", [ms_cap])
    cs_sent = candidate_subclass("cs_sent", [ms_sent])

    # NOTE(review): call truncated in this view (trailing comma, no closing
    # parenthesis) -- TODO restore the remaining arguments.
    candidate_extractor_udf = CandidateExtractorUDF(
        [cs_doc, cs_sec, cs_tab, cs_fig, cs_cell, cs_para, cs_cap, cs_sent],
        [None, None, None, None, None, None, None, None],

    doc = candidate_extractor_udf.apply(doc, split=0)

    # Each candidate count is expected to equal the matching mention count
    # asserted above.
    assert len(doc.cs_docs) == 1
    assert len(doc.cs_caps) == 2
    assert len(doc.cs_secs) == 5
    assert len(doc.cs_tabs) == 2
    assert len(doc.cs_figs) == 2
    assert len(doc.cs_paras) == 30
    assert len(doc.cs_sents) == 35
    assert len(doc.cs_cells) == 21
def test_cand_gen():
    """Test extracting candidates from mentions from documents."""

    # Accept-everything predicate wrapped by LambdaFunctionFigureMatcher below.
    def do_nothing_matcher(fig):
        return True

    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"
    doc = parse_doc(docs_path, "112823", pdf_path)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    # NOTE(review): in both pytest.raises bodies below the line that opens the
    # call (and its closing parenthesis) is missing in this view; only the
    # argument lists remain. TODO: restore from the original file.
    with pytest.raises(ValueError):
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
    # NOTE(review): temp_matcher sits in the mention-space list on the second
    # line below; it looks like temp_ngrams was intended -- confirm.
    with pytest.raises(ValueError):
            [Part, Temp, Volt],
            [part_ngrams, temp_matcher, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity

    # NOTE(review): the closing parenthesis of this call is missing in this view.
    mention_extractor_udf = MentionExtractorUDF(
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    doc = mention_extractor_udf.apply(doc)

    assert len(doc.parts) == 70
    assert len(doc.volts) == 33
    assert len(doc.temps) == 23
    assert len(doc.figs) == 31
    part = doc.parts[0]
    volt = doc.volts[0]
    temp = doc.temps[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Volt: {volt.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    # NOTE(review): the two pytest.raises bodies below are also missing their
    # call header lines; the first has an unmatched "]" as shown.
    with pytest.raises(ValueError):
            [PartTemp, PartVolt],
            ],  # Fail, mismatched arity

    with pytest.raises(ValueError):
            [PartTemp],  # Fail, mismatched arity
            throttlers=[temp_throttler, volt_throttler],

    # Test that no throttler in candidate extractor
    # NOTE(review): the closing parenthesis of this call is missing in this view.
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [None, None], False, False, True  # Pass, no throttler

    doc = candidate_extractor_udf.apply(doc, split=0)

    assert len(doc.part_temps) == 1610
    assert len(doc.part_volts) == 2310

    # Clear
    doc.part_temps = []
    doc.part_volts = []

    # Test with None in throttlers in candidate extractor
    # NOTE(review): the closing parenthesis of this call is missing in this view.
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [temp_throttler, None], False, False, True

    doc = candidate_extractor_udf.apply(doc, split=0)
    # Only PartTemp is throttled here, so its count drops while the PartVolt
    # count stays at the unthrottled value asserted above.
    assert len(doc.part_temps) == 1432
    assert len(doc.part_volts) == 2310

    # Clear
    doc.part_temps = []
    doc.part_volts = []

    # Both candidate classes throttled.
    # NOTE(review): the closing parenthesis of this call is missing in this view.
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [temp_throttler, volt_throttler], False, False, True

    doc = candidate_extractor_udf.apply(doc, split=0)

    assert len(doc.part_temps) == 1432
    assert len(doc.part_volts) == 1993
    # Mention counts are expected to be unchanged by candidate extraction.
    assert len(doc.parts) == 70
    assert len(doc.volts) == 33
    assert len(doc.temps) == 23