コード例 #1
0
def test_row_col_ngram_extraction():
    """Check row/column ngrams for mentions inside and outside of tables.

    For a mention whose sentence is not tabular, both ngram lists must hold
    exactly one ``None`` entry; for a tabular mention neither list may
    contain ``None``.
    """
    file_name = "lincoln_short"
    doc = parse_doc(f"tests/data/pure_html/{file_name}.html", file_name)

    # Mention Extraction
    place_ngrams = MentionNgramsTemp(n_max=4)
    Place = mention_subclass("Place")

    def get_row_and_column_ngrams(mention):
        """Validate the ngram lists of ``mention`` and match on "birth_place"."""
        rows = list(get_row_ngrams(mention))
        cols = list(get_col_ngrams(mention))
        if mention.sentence.is_tabular():
            # Inside a table no placeholder entries are expected.
            assert all(x is not None for x in rows)
            assert all(x is not None for x in cols)
        else:
            # Outside a table each list is a single ``None`` placeholder.
            assert len(rows) == 1 and rows[0] is None
            assert len(cols) == 1 and cols[0] is None
        return "birth_place" in rows

    birthplace_matcher = LambdaFunctionMatcher(func=get_row_and_column_ngrams)
    extractor = MentionExtractorUDF(
        [Place], [place_ngrams], [birthplace_matcher]
    )

    doc = extractor.apply(doc)
コード例 #2
0
def test_candidate_with_nullable_mentions():
    """Verify that a candidate may be built around a NULL (``None``) mention."""
    doc = parse_doc("tests/data/html/112823.html", "112823", "tests/data/pdf/")

    # Mention Extraction
    MentionTemp = mention_subclass("MentionTemp")
    temp_space = MentionNgramsTemp(n_max=2)
    mention_extractor = MentionExtractorUDF(
        [MentionTemp], [temp_space], [temp_matcher]
    )
    doc = mention_extractor.apply(doc)

    assert len(doc.mention_temps) == 23

    # Candidate Extraction
    CandidateTemp = candidate_subclass(
        "CandidateTemp", [MentionTemp], nullables=[True]
    )
    candidate_extractor = CandidateExtractorUDF(
        [CandidateTemp], [None], False, False, True
    )
    doc = candidate_extractor.apply(doc, split=0)

    # One candidate per mention, plus one extra candidate for NULL.
    expected_total = len(doc.mention_temps) + 1
    assert len(doc.candidate_temps) == expected_total

    # One of the extracted candidates must carry the NULL mention.
    leading_mentions = [cand[0] for cand in doc.candidate_temps]
    assert None in leading_mentions
コード例 #3
0
def test_visualizer():
    """Unit test of visualizer using the md document.

    Parses the md document, extracts ``Org`` mentions and unary
    ``Organization`` candidates, then asks the visualizer to display the
    first candidate.
    """
    # Imported inside the test body rather than at module level.
    from fonduer.utils.visualizer import Visualizer  # noqa

    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/md.pdf"

    # Grab the md document
    doc = parse_doc(docs_path, "md", pdf_path)
    assert doc.name == "md"

    organization_ngrams = MentionNgrams(n_max=1)

    Org = mention_subclass("Org")

    organization_matcher = OrganizationMatcher()

    mention_extractor_udf = MentionExtractorUDF([Org], [organization_ngrams],
                                                [organization_matcher])

    doc = mention_extractor_udf.apply(doc)

    Organization = candidate_subclass("Organization", [Org])

    candidate_extractor_udf = CandidateExtractorUDF([Organization], None,
                                                    False, False, True)

    doc = candidate_extractor_udf.apply(doc, split=0)

    cands = doc.organizations

    # Test visualizer
    # The visualizer is given the PDF directory, not the single PDF file
    # that was used for parsing above.
    pdf_path = "tests/data/pdf_simple"
    vis = Visualizer(pdf_path)
    vis.display_candidates([cands[0]])
コード例 #4
0
def test_visualizer():
    """Exercise the visualizer and its bounding boxes on the md document."""
    from fonduer.utils.visualizer import Visualizer, get_box  # noqa

    # Grab the md document
    doc = parse_doc(
        "tests/data/html_simple/md.html", "md", "tests/data/pdf_simple/"
    )
    assert doc.name == "md"

    # Mention extraction for organizations (unigrams only)
    org_space = MentionNgrams(n_max=1)
    Org = mention_subclass("Org")
    org_matcher = OrganizationMatcher()
    doc = MentionExtractorUDF([Org], [org_space], [org_matcher]).apply(doc)

    # Candidate extraction
    Organization = candidate_subclass("Organization", [Org])
    cand_extractor = CandidateExtractorUDF(
        [Organization], None, False, False, True
    )
    doc = cand_extractor.apply(doc, split=0)

    # Take one candidate
    cand = doc.organizations[0]

    vis = Visualizer("tests/data/pdf_simple")

    # Test bounding boxes
    boxes = [get_box(m.context) for m in cand.get_mentions()]
    for bbox in boxes:
        assert bbox.top <= bbox.bottom
        assert bbox.left <= bbox.right
    # ``get_box`` must agree with the bounding box reported by each context.
    expected_boxes = [m.context.get_bbox() for m in cand.get_mentions()]
    assert boxes == expected_boxes

    # Test visualizer
    vis.display_candidates([cand])
コード例 #5
0
def email_extract_server(document, email_subclass):
    """Run e-mail mention extraction on ``document`` and return the result.

    :param document: Document handed to ``MentionExtractorUDF.apply``.
    :param email_subclass: Mention class under which the matches are stored.
    :return: Whatever ``MentionExtractorUDF.apply`` returns for ``document``.
    """
    matcher = LambdaFunctionMatcher(func=email_mc, longest_match_only=True)
    space = MentionEmails()
    extractor = MentionExtractorUDF([email_subclass], [space], [matcher])
    return extractor.apply(document)
コード例 #6
0
def phone_extract_server(document, phone_subclass):
    """Run phone-number mention extraction on ``document``.

    The matcher is a ``Union`` of a regex-based lambda matcher and a
    number-based lambda matcher.

    :param document: Document handed to ``MentionExtractorUDF.apply``.
    :param phone_subclass: Mention class under which the matches are stored.
    :return: Whatever ``MentionExtractorUDF.apply`` returns for ``document``.
    """
    number_matcher = LambdaFunctionMatcher(func=matcher_number_phone)
    regex_matcher = LambdaFunctionMatcher(func=regexMatch)
    combined_matcher = Union(regex_matcher, number_matcher)

    space = MentionPhoneNumber()

    extractor = MentionExtractorUDF(
        [phone_subclass], [space], [combined_matcher]
    )
    return extractor.apply(document)
コード例 #7
0
ファイル: test_candidates.py プロジェクト: gadgetlabs/fonduer
def test_multimodal_cand():
    """Test multimodal candidate generation"""
    file_name = "radiology"
    doc = parse_doc(f"tests/data/pure_html/{file_name}.html", file_name)

    assert len(doc.sentences) == 35

    # Mention Extraction: one mention class and one mention space per
    # context type, keyed by the mention class name.
    class_names = (
        "m_doc", "m_sec", "m_tab", "m_fig",
        "m_cell", "m_para", "m_cap", "m_sent",
    )
    classes = {name: mention_subclass(name) for name in class_names}
    spaces = {
        "m_doc": MentionDocuments(),
        "m_sec": MentionSections(),
        "m_tab": MentionTables(),
        "m_fig": MentionFigures(),
        "m_cell": MentionCells(),
        "m_para": MentionParagraphs(),
        "m_cap": MentionCaptions(),
        "m_sent": MentionSentences(),
    }

    # Order in which classes and spaces are handed to the extractor.
    order = (
        "m_doc", "m_cap", "m_sec", "m_tab",
        "m_fig", "m_para", "m_sent", "m_cell",
    )
    mention_extractor = MentionExtractorUDF(
        [classes[name] for name in order],
        [spaces[name] for name in order],
        [DoNothingMatcher()] * len(order),
    )

    doc = mention_extractor.apply(doc)

    expected_counts = {
        "m_docs": 1,
        "m_caps": 2,
        "m_secs": 5,
        "m_tabs": 2,
        "m_figs": 2,
        "m_paras": 30,
        "m_sents": 35,
        "m_cells": 21,
    }
    for attr, expected in expected_counts.items():
        assert len(getattr(doc, attr)) == expected
コード例 #8
0
def birthday_extract_server(document, birthday_subclass):
    """Run birthday mention extraction on ``document``.

    Two lambda matchers (both with ``longest_match_only=True``) are combined
    with ``Intersect`` and applied over the ``MentionDates`` space.

    :param document: Document handed to ``MentionExtractorUDF.apply``.
    :param birthday_subclass: Mention class under which matches are stored.
    :return: Whatever ``MentionExtractorUDF.apply`` returns for ``document``.
    """
    parts = [
        LambdaFunctionMatcher(func=predicate, longest_match_only=True)
        for predicate in (filter_birthday, birthday_conditions)
    ]
    matcher = Intersect(*parts)
    space = MentionDates()

    extractor = MentionExtractorUDF([birthday_subclass], [space], [matcher])
    return extractor.apply(document)
コード例 #9
0
def address_extract_server(document, address_subclass):
    """Run address mention extraction on ``document``.

    Five lambda matchers are combined as
    ``Intersect(Union(m1, m2, m3), m4, m5)`` and applied over the
    ``MentionSentences`` space.

    :param document: Document handed to ``MentionExtractorUDF.apply``.
    :param address_subclass: Mention class under which matches are stored.
    :return: Whatever ``MentionExtractorUDF.apply`` returns for ``document``.
    """
    province_matcher = LambdaFunctionMatcher(func=has_province_address)
    geo_term_matcher = LambdaFunctionMatcher(func=has_geographic_term_address)
    prefix_matcher = LambdaFunctionMatcher(func=address_prefix)
    composition_matcher = LambdaFunctionMatcher(
        func=is_collection_of_number_and_geographical_term_and_provinces_name_address
    )
    ignore_words_matcher = LambdaFunctionMatcher(func=hasnt_ignor_words)

    any_hint = Union(province_matcher, geo_term_matcher, prefix_matcher)
    matcher = Intersect(any_hint, composition_matcher, ignore_words_matcher)

    space = MentionSentences()

    extractor = MentionExtractorUDF([address_subclass], [space], [matcher])
    return extractor.apply(document)
コード例 #10
0
def test_ngrams():
    """Test ngram limits in mention extraction"""
    file_name = "lincoln_short"
    doc = parse_doc(f"tests/data/pure_html/{file_name}.html", file_name)

    # Mention Extraction
    Person = mention_subclass("Person")
    person_ngrams = MentionNgrams(n_max=3)
    person_matcher = PersonMatcher()

    extractor = MentionExtractorUDF(
        [Person], [person_ngrams], [person_matcher]
    )
    doc = extractor.apply(doc)

    assert len(doc.persons) == 118
    word_counts = [m.context.get_num_words() for m in doc.persons]
    assert word_counts.count(1) == 49
    assert not any(n_words > 3 for n_words in word_counts)

    # Test for unigram exclusion: drop every mention found so far first.
    # Iterate over a copy because the list is mutated while looping.
    for mention in doc.persons[:]:
        doc.persons.remove(mention)
    assert len(doc.persons) == 0

    person_ngrams = MentionNgrams(n_min=2, n_max=3)
    extractor = MentionExtractorUDF(
        [Person], [person_ngrams], [person_matcher]
    )
    doc = extractor.apply(doc)

    assert len(doc.persons) == 69
    word_counts = [m.context.get_num_words() for m in doc.persons]
    assert word_counts.count(1) == 0
    assert not any(n_words > 3 for n_words in word_counts)
コード例 #11
0
def _load_pyfunc(model_path: str) -> Any:
    """Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.

    Restores the mention and candidate classes, unpickles
    ``<model_path>/model.pkl`` and rebuilds the parser, mention extractor and
    candidate extractor from the stored keyword arguments. Depending on the
    configured model type, either the emmental model (with featurizer) or the
    label models (with labeler) are attached to the returned object.

    :param model_path: Directory that contains ``model.pkl``.
    :return: The ``fonduer_model`` object stored in the pickle, with its
        runtime components attached.
    """

    # Load mention_classes
    _load_mention_classes(model_path)
    # Load candidate_classes
    _load_candidate_classes(model_path)
    # Load a pickled model. The context manager closes the file handle even
    # if unpickling fails.
    # NOTE(review): unpickling can execute arbitrary code; only load model
    # directories that come from a trusted source.
    with open(os.path.join(model_path, "model.pkl"), "rb") as model_file:
        model = pickle.load(model_file)
    fonduer_model = model["fonduer_model"]
    # NOTE(review): the key is spelled "preprosessor" in the stored payload;
    # presumably it matches the code that saves the model -- confirm before
    # renaming.
    fonduer_model.preprocessor = model["preprosessor"]
    fonduer_model.parser = ParserUDF(**model["parser"])
    fonduer_model.mention_extractor = MentionExtractorUDF(
        **model["mention_extractor"])
    fonduer_model.candidate_extractor = CandidateExtractorUDF(
        **model["candidate_extractor"])

    # Configure logging for Fonduer
    init_logging(log_dir="logs")

    pyfunc_conf = _get_flavor_configuration(model_path=model_path,
                                            flavor_name=pyfunc.FLAVOR_NAME)
    candidate_classes = fonduer_model.candidate_extractor.candidate_classes

    # Fall back to "emmental" when the flavor configuration has no model type.
    fonduer_model.model_type = pyfunc_conf.get(MODEL_TYPE, "emmental")
    if fonduer_model.model_type == "emmental":
        emmental.init()
        fonduer_model.featurizer = FeaturizerUDF(candidate_classes,
                                                 FeatureExtractor())
        fonduer_model.key_names = model["feature_keys"]
        fonduer_model.word2id = model["word2id"]

        # Load the emmental_model from the serialized bytes kept in the
        # pickle, via an in-memory buffer.
        buffer = BytesIO()
        buffer.write(model["emmental_model"])
        buffer.seek(0)
        fonduer_model.emmental_model = torch.load(buffer)
    else:
        fonduer_model.labeler = LabelerUDF(candidate_classes)
        fonduer_model.key_names = model["labeler_keys"]

        fonduer_model.lfs = model["lfs"]

        # Rebuild each label model by restoring its saved attribute dict.
        fonduer_model.label_models = []
        for state_dict in model["label_models_state_dict"]:
            label_model = LabelModel()
            label_model.__dict__.update(state_dict)
            fonduer_model.label_models.append(label_model)
    return fonduer_model
コード例 #12
0
def test_mention_longest_match():
    """Test longest match filtering in mention extraction."""
    file_name = "lincoln_short"
    doc = parse_doc(f"tests/data/pure_html/{file_name}.html", file_name)

    # Mention Extraction
    name_ngrams = MentionNgramsPart(n_max=3)
    place_ngrams = MentionNgramsTemp(n_max=4)

    Name = mention_subclass("Name")
    Place = mention_subclass("Place")

    def is_birthplace_table_row(mention):
        """Match tabular mentions whose row ngrams include "birth_place"."""
        if mention.sentence.is_tabular():
            return "birth_place" in get_row_ngrams(mention, lower=True)
        return False

    def build_extractor(longest_match_only):
        """Create the Name/Place extractor with the given longest-match flag."""
        birthplace_matcher = LambdaFunctionMatcher(
            func=is_birthplace_table_row,
            longest_match_only=longest_match_only,
        )
        return MentionExtractorUDF(
            [Name, Place],
            [name_ngrams, place_ngrams],
            [PersonMatcher(), birthplace_matcher],
        )

    # Without longest-match filtering the shorter span is kept as well.
    doc = build_extractor(False).apply(doc)
    spans = [m.context.get_span() for m in doc.places]
    assert "Sinking Spring Farm" in spans
    assert "Farm" in spans
    assert len(spans) == 23

    # Clear manually (iterate over a copy since the list is mutated)
    for mention in doc.places[:]:
        doc.places.remove(mention)

    # With longest-match filtering only the longest span remains.
    doc = build_extractor(True).apply(doc)
    spans = [m.context.get_span() for m in doc.places]
    assert "Sinking Spring Farm" in spans
    assert "Farm" not in spans
    assert len(spans) == 4
コード例 #13
0
def name_extract_server(document, name_subclass):
    """Run name mention extraction on ``document``.

    The final matcher has the shape
    ``Intersect(Union(Intersect(last, form), Intersect(common, form), prefix),
    check)`` where ``form`` is ``Intersect(length, position, capitalize)``.

    :param document: Document handed to ``MentionExtractorUDF.apply``.
    :param name_subclass: Mention class under which matches are stored.
    :return: Whatever ``MentionExtractorUDF.apply`` returns for ``document``.
    """
    # Matchers describing the form of a span
    by_length = LambdaFunctionMatcher(func=length_name)
    by_position = LambdaFunctionMatcher(func=position_name)
    by_capitalization = LambdaFunctionMatcher(func=capitalize_name)

    # Matchers describing the content of a span
    by_last_name = LambdaFunctionMatcher(func=last_name)
    by_common_name = LambdaFunctionMatcher(func=name_common)
    by_check = LambdaFunctionMatcher(func=check_name)
    by_prefix = LambdaFunctionMatcher(func=prefix_name)

    form = Intersect(by_length, by_position, by_capitalization)
    alternatives = Union(
        Intersect(by_last_name, form),
        Intersect(by_common_name, form),
        by_prefix,
    )
    matcher = Intersect(alternatives, by_check)

    space = MentionName()

    extractor = MentionExtractorUDF([name_subclass], [space], [matcher])
    return extractor.apply(document)
コード例 #14
0
def test_multinary_relation_feature_extraction():
    """Test feature extraction for a ternary (Part, Temp, Volt) relation.

    Checks that the number of default feature keys equals the sum of the
    textual, tabular, structural and visual key counts, and that a custom
    feature function adds exactly one key per candidate.
    """
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    # Parsing
    doc = parse_doc(docs_path, "112823", pdf_path)
    assert len(doc.sentences) == 799

    # Mention Extraction (unigram spaces for all three mention classes)
    ngram_spaces = [MentionNgrams(n_max=1) for _ in range(3)]
    Part, Temp, Volt = (
        mention_subclass(name) for name in ("Part", "Temp", "Volt")
    )

    mention_extractor = MentionExtractorUDF(
        [Part, Temp, Volt],
        ngram_spaces,
        [part_matcher, temp_matcher, volt_matcher],
    )
    doc = mention_extractor.apply(doc)

    assert len(doc.parts) == 62
    assert len(doc.temps) == 16
    assert len(doc.volts) == 33
    part = doc.parts[0]
    temp = doc.temps[0]
    volt = doc.volts[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")
    logger.info(f"Volt: {volt.context}")

    # Candidate Extraction
    PartTempVolt = candidate_subclass("PartTempVolt", [Part, Temp, Volt])
    candidate_extractor = CandidateExtractorUDF(
        [PartTempVolt], None, False, False, True
    )
    doc = candidate_extractor.apply(doc, split=0)

    # Manually set id as it is not set automatically b/c a database is not used.
    for index, cand in enumerate(doc.part_temp_volts):
        cand.id = index

    n_cands = len(doc.part_temp_volts)

    def count_feature_keys(featurizer):
        """Featurize ``doc`` and return the number of distinct feature keys."""
        features = itertools.chain.from_iterable(featurizer.apply(doc))
        return len({key for feature in features for key in feature["keys"]})

    # Featurization based on default feature library
    n_default_feats = count_feature_keys(
        FeaturizerUDF([PartTempVolt], FeatureExtractor())
    )

    # Example feature extractor
    def feat_ext(candidates):
        if not isinstance(candidates, list):
            candidates = [candidates]
        for candidate in candidates:
            yield candidate.id, f"cand_id_{candidate.id}", 1

    # Featurization with one extra feature extractor
    feature_extractors = FeatureExtractor(customize_feature_funcs=[feat_ext])
    n_default_w_customized_features = count_feature_keys(
        FeaturizerUDF([PartTempVolt], feature_extractors=feature_extractors)
    )

    # Example spurious feature extractor
    def bad_feat_ext(candidates):
        raise RuntimeError()

    # Featurization with a spurious feature extractor must propagate its error
    feature_extractors = FeatureExtractor(
        customize_feature_funcs=[bad_feat_ext]
    )
    featurizer_udf = FeaturizerUDF(
        [PartTempVolt], feature_extractors=feature_extractors
    )
    logger.info("Featurizing with a spurious feature extractor...")
    with pytest.raises(RuntimeError):
        featurizer_udf.apply(doc)

    # Featurization restricted to a single feature library at a time
    per_library = {}
    for library in ("textual", "tabular", "structural", "visual"):
        feature_extractors = FeatureExtractor(features=[library])
        per_library[library] = count_feature_keys(
            FeaturizerUDF(
                [PartTempVolt], feature_extractors=feature_extractors
            )
        )

    assert n_default_feats == sum(per_library.values())

    assert n_default_w_customized_features == n_default_feats + n_cands
コード例 #15
0
def test_unary_relation_feature_extraction():
    """Test feature extraction for unary candidates built from mentions.

    Checks that the number of default feature keys equals the sum of the
    textual, tabular, structural and visual key counts.
    """
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    # Parsing
    doc = parse_doc(docs_path, "112823", pdf_path)
    assert len(doc.sentences) == 799

    # Mention Extraction
    part_space = MentionNgrams(n_max=1)
    Part = mention_subclass("Part")
    mention_extractor = MentionExtractorUDF(
        [Part], [part_space], [part_matcher]
    )
    doc = mention_extractor.apply(doc)

    assert doc.name == "112823"
    assert len(doc.parts) == 62
    part = doc.parts[0]
    logger.info(f"Part: {part.context}")

    # Candidate Extraction
    PartRel = candidate_subclass("PartRel", [Part])
    candidate_extractor = CandidateExtractorUDF(
        [PartRel], None, False, False, True
    )
    doc = candidate_extractor.apply(doc, split=0)

    def count_feature_keys(featurizer):
        """Featurize ``doc`` and return the number of distinct feature keys."""
        features = itertools.chain.from_iterable(featurizer.apply(doc))
        return len({key for feature in features for key in feature["keys"]})

    # Featurization based on default feature library
    n_default_feats = count_feature_keys(
        FeaturizerUDF([PartRel], FeatureExtractor())
    )

    # Featurization restricted to a single feature library at a time
    per_library = {}
    for library in ("textual", "tabular", "structural", "visual"):
        feature_extractors = FeatureExtractor(features=[library])
        per_library[library] = count_feature_keys(
            FeaturizerUDF([PartRel], feature_extractors=feature_extractors)
        )

    assert n_default_feats == sum(per_library.values())
コード例 #16
0
def test_cand_gen():
    """Test extracting candidates from mentions from documents.

    Covers three things: ``MentionExtractor`` and ``CandidateExtractor``
    raising ``ValueError`` on mismatched argument arity, the mention counts
    for Part/Temp/Volt/Fig, and the candidate counts with no throttler, one
    throttler and two throttlers.
    """

    def do_nothing_matcher(fig):
        """Accept every figure."""
        return True

    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"
    doc = parse_doc(docs_path, "112823", pdf_path)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    with pytest.raises(ValueError):
        MentionExtractor(
            "dummy",
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    with pytest.raises(ValueError):
        # The mention spaces are well-formed here so that only the matcher
        # list has the wrong length.
        MentionExtractor(
            "dummy",
            [Part, Temp, Volt],
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor_udf = MentionExtractorUDF(
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    doc = mention_extractor_udf.apply(doc)

    assert len(doc.parts) == 70
    assert len(doc.volts) == 33
    assert len(doc.temps) == 23
    assert len(doc.figs) == 31
    part = doc.parts[0]
    volt = doc.volts[0]
    temp = doc.temps[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Volt: {volt.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    with pytest.raises(ValueError):
        CandidateExtractor(
            "dummy",
            [PartTemp, PartVolt],
            throttlers=[
                temp_throttler,
                volt_throttler,
                volt_throttler,
            ],  # Fail, mismatched arity
        )

    with pytest.raises(ValueError):
        CandidateExtractor(
            "dummy",
            [PartTemp],  # Fail, mismatched arity
            throttlers=[temp_throttler, volt_throttler],
        )

    # Test that no throttler in candidate extractor
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [None, None], False, False, True  # Pass, no throttler
    )

    doc = candidate_extractor_udf.apply(doc, split=0)

    assert len(doc.part_temps) == 1610
    assert len(doc.part_volts) == 2310

    # Clear the candidates before the next extraction run
    doc.part_temps = []
    doc.part_volts = []

    # Test with None in throttlers in candidate extractor
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [temp_throttler, None], False, False, True
    )

    doc = candidate_extractor_udf.apply(doc, split=0)
    assert len(doc.part_temps) == 1432
    assert len(doc.part_volts) == 2310

    # Clear the candidates before the next extraction run
    doc.part_temps = []
    doc.part_volts = []

    # Test with a throttler for each candidate class
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [temp_throttler, volt_throttler], False, False, True
    )

    doc = candidate_extractor_udf.apply(doc, split=0)

    assert len(doc.part_temps) == 1432
    assert len(doc.part_volts) == 1993
    # The mention counts are the same as before candidate extraction.
    assert len(doc.parts) == 70
    assert len(doc.volts) == 33
    assert len(doc.temps) == 23