示例#1
0
def test_subclass_before_meta_init():
    """Test if mention (candidate) subclass can be created before Meta init.

    One mention subclass is created before this test calls ``Meta.init``; a
    scratch database is then created, ``Meta`` is initialized against it and a
    second mention subclass is created. The session, engine, database, cursor
    and connection are released even if a step in between raises, so a failure
    does not leave the scratch database behind for later runs.
    """
    # Test if mention (candidate) subclass can be created
    Part = mention_subclass("Part")
    logger.info(f"Create a mention subclass '{Part.__tablename__}'")

    # Setup a database
    con = psycopg2.connect(
        host=os.environ["POSTGRES_HOST"],
        port=os.environ["POSTGRES_PORT"],
        user=os.environ["PGUSER"],
        password=os.environ["PGPASSWORD"],
    )
    try:
        # Autocommit so create/drop database run outside a transaction block.
        con.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        cursor = con.cursor()
        cursor.execute(f'create database "{DB}";')

        session = None
        try:
            session = Meta.init(CONN_STRING).Session()

            # Test if another mention subclass can be created
            Temp = mention_subclass("Temp")
            logger.info(f"Create a mention subclass '{Temp.__tablename__}'")
        finally:
            # Teardown the database; runs on success and on failure alike.
            if session is not None:
                session.close()
            if getattr(Meta, "engine", None) is not None:
                Meta.engine.dispose()
                Meta.engine = None

            cursor.execute(f'drop database "{DB}";')
            cursor.close()
    finally:
        con.close()
示例#2
0
def get_subclasses(experiment):
    """Assemble the extraction configuration for ``experiment``.

    Args:
        experiment: Forwarded unchanged to ``get_label_matcher`` when the
            label-based matchers are built.

    Returns:
        A 5-tuple ``(mention_classes, mention_spaces, matchers,
        candidate_classes, throttlers)``. The first three lists are aligned
        in Data, Row, Col order; the last two are aligned in
        RowCandidate, ColCandidate order.
    """
    # Mention subclasses.
    Data = mention_subclass("Data")
    Row = mention_subclass("Row")
    Col = mention_subclass("Col")

    # Mention spaces: a separate sentence-level space per mention class.
    data_space = MentionSentences()
    row_space = MentionSentences()
    col_space = MentionSentences()

    # Matchers: each one intersects a regex span matcher with a label matcher.
    data_matcher = Intersect(
        RegexMatchSpan(
            rgx=r"[0-9-,.%$#]+( to | )?[0-9-,.%$#]*|^x$", longest_match_only=True
        ),
        LambdaFunctionMatcher(func=get_label_matcher("Data", experiment)),
    )
    row_matcher = Intersect(
        RegexMatchSpan(rgx=r"^.*$", longest_match_only=True),
        LambdaFunctionMatcher(func=get_label_matcher("Header", experiment)),
    )
    col_matcher = Intersect(
        RegexMatchSpan(rgx=r"^.*$", longest_match_only=True),
        LambdaFunctionMatcher(func=get_label_matcher("Header", experiment)),
    )

    # Candidate classes pair every Data mention with a Row or a Col mention.
    RowCandidate = candidate_subclass("RowCandidate", [Data, Row])
    ColCandidate = candidate_subclass("ColCandidate", [Data, Col])

    return (
        [Data, Row, Col],
        [data_space, row_space, col_space],
        [data_matcher, row_matcher, col_matcher],
        [RowCandidate, ColCandidate],
        [row_filter, col_filter],
    )
示例#3
0
def test_subclass_before_meta_init():
    """Check that mention subclasses can be made on both sides of ``Meta.init``."""
    # Subclass created ahead of this test's Meta.init call.
    Part = mention_subclass("Part")
    logger.info(f"Create a mention subclass '{Part.__tablename__}'")

    # Initialize Meta and open a session.
    conn_string = "postgresql://localhost:5432/" + DB
    meta = Meta.init(conn_string)
    meta.Session()

    # Subclass created after Meta.init has been called.
    Temp = mention_subclass("Temp")
    logger.info(f"Create a mention subclass '{Temp.__tablename__}'")
示例#4
0
def test_subclass_before_meta_init():
    """Check that a mention subclass can be created before ``Meta.init``.

    A second subclass is created after the ``Meta.init`` call as well.
    """
    # Created ahead of this test's Meta.init call.
    Part = mention_subclass("Part")
    logger.info(f"Create a mention subclass '{Part.__tablename__}'")

    # Initialize Meta with the shared connection string and open a session.
    meta = Meta.init(CONN_STRING)
    meta.Session()

    # Created once Meta.init has been called.
    Temp = mention_subclass("Temp")
    logger.info(f"Create a mention subclass '{Temp.__tablename__}'")
示例#5
0
def test_mention_longest_match():
    """Check how ``longest_match_only`` changes the extracted Place mentions.

    The same document is run through mention extraction twice, once with
    ``longest_match_only=False`` and once with ``True`` on the birthplace
    matcher, and the resulting Place spans are compared.
    """
    file_name = "lincoln_short"
    doc = parse_doc(f"tests/data/pure_html/{file_name}.html", file_name)

    # Mention spaces and mention classes
    name_space = MentionNgramsPart(n_max=3)
    place_space = MentionNgramsTemp(n_max=4)
    Name = mention_subclass("Name")
    Place = mention_subclass("Place")

    def is_birthplace_table_row(mention):
        """Return True for tabular mentions whose row ngrams hold "birth_place"."""
        if not mention.sentence.is_tabular():
            return False
        return "birth_place" in get_row_ngrams(mention, lower=True)

    def extract_place_spans(document, longest_match_only):
        """Run mention extraction; return the new document and its place spans."""
        matcher = LambdaFunctionMatcher(
            func=is_birthplace_table_row, longest_match_only=longest_match_only
        )
        udf = MentionExtractorUDF(
            [Name, Place],
            [name_space, place_space],
            [PersonMatcher(), matcher],
        )
        document = udf.apply(document)
        return document, [place.context.get_span() for place in document.places]

    # Without longest-match filtering the sub-span "Farm" is kept as well.
    doc, spans = extract_place_spans(doc, False)
    assert "Sinking Spring Farm" in spans
    assert "Farm" in spans
    assert len(spans) == 23

    # Clear manually
    for stale in list(doc.places):
        doc.places.remove(stale)

    # With longest-match filtering only the longest span is kept.
    doc, spans = extract_place_spans(doc, True)
    assert "Sinking Spring Farm" in spans
    assert "Farm" not in spans
    assert len(spans) == 4
示例#6
0
def test_subclass_before_meta_init(caplog):
    """Check that a mention subclass can be created before ``Meta.init``.

    A second subclass is created after the ``Meta.init`` call as well.
    """
    caplog.set_level(logging.INFO)

    # Created ahead of this test's Meta.init call.
    Part = mention_subclass("Part")
    logger.info(f"Create a mention subclass '{Part.__tablename__}'")

    # Initialize Meta and open a session.
    meta = Meta.init("postgresql://localhost:5432/" + DB)
    meta.Session()

    # Created once Meta.init has been called.
    Temp = mention_subclass("Temp")
    logger.info(f"Create a mention subclass '{Temp.__tablename__}'")
示例#7
0
def test_multimodal_cand(caplog):
    """Check mention counts for every multimodal mention space.

    One document is parsed into the database, mentions are extracted with a
    ``DoNothingMatcher`` for each context type, and the number of stored
    mentions per subclass is compared with the expected value.
    """
    caplog.set_level(logging.INFO)

    PARALLEL = 4
    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor(
        "tests/data/pure_html/radiology.html", max_docs=max_docs
    )
    parser = Parser(session, structural=True, lingual=True)
    parser.apply(preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 35
    docs = session.query(Document).order_by(Document.name).all()

    # Mention subclasses, keyed by name.
    subclass_names = (
        "m_doc",
        "m_sec",
        "m_tab",
        "m_fig",
        "m_cell",
        "m_para",
        "m_cap",
        "m_sent",
    )
    subclasses = {name: mention_subclass(name) for name in subclass_names}

    # One mention space per subclass, keyed by the same names.
    spaces = {
        "m_doc": MentionDocuments(),
        "m_sec": MentionSections(),
        "m_tab": MentionTables(),
        "m_fig": MentionFigures(),
        "m_cell": MentionCells(),
        "m_para": MentionParagraphs(),
        "m_cap": MentionCaptions(),
        "m_sent": MentionSentences(),
    }

    # Order handed to the extractor, with the expected count per subclass.
    expected_counts = {
        "m_doc": 1,
        "m_cap": 2,
        "m_sec": 5,
        "m_tab": 2,
        "m_fig": 2,
        "m_para": 30,
        "m_sent": 35,
        "m_cell": 21,
    }
    ms = [subclasses[name] for name in expected_counts]
    m = [spaces[name] for name in expected_counts]
    matchers = [DoNothingMatcher()] * 8

    extractor = MentionExtractor(session, ms, m, matchers, parallelism=PARALLEL)
    extractor.apply(docs)

    for name, expected in expected_counts.items():
        assert session.query(subclasses[name]).count() == expected
示例#8
0
def test_candidate_with_nullable_mentions():
    """Verify that a candidate may be extracted with a NULL mention."""
    doc = parse_doc("tests/data/html/112823.html", "112823", "tests/data/pdf/")

    # Mention extraction
    MentionTemp = mention_subclass("MentionTemp")
    mention_udf = MentionExtractorUDF(
        [MentionTemp], [MentionNgramsTemp(n_max=2)], [temp_matcher]
    )
    doc = mention_udf.apply(doc)

    assert len(doc.mention_temps) == 23

    # Candidate extraction with the single mention slot marked nullable
    CandidateTemp = candidate_subclass(
        "CandidateTemp", [MentionTemp], nullables=[True]
    )
    candidate_udf = CandidateExtractorUDF(
        [CandidateTemp], [None], False, False, True
    )
    doc = candidate_udf.apply(doc, split=0)

    # One candidate per mention, plus one extra candidate for NULL.
    n_mentions = len(doc.mention_temps)
    assert len(doc.candidate_temps) == n_mentions + 1

    # At least one extracted candidate carries a NULL mention.
    first_mentions = [cand[0] for cand in doc.candidate_temps]
    assert None in first_mentions
示例#9
0
def test_row_col_ngram_extraction():
    """Check row/column ngrams for mentions inside and outside of tables.

    The checks live in the matcher function, so they run for every mention
    the extractor offers to the matcher.
    """
    file_name = "lincoln_short"
    doc = parse_doc(f"tests/data/pure_html/{file_name}.html", file_name)

    # Mention space and mention class
    place_space = MentionNgramsTemp(n_max=4)
    Place = mention_subclass("Place")

    def get_row_and_column_ngrams(mention):
        """Assert on the mention's row/column ngrams; match "birth_place" rows."""
        row_ngrams = list(get_row_ngrams(mention))
        col_ngrams = list(get_col_ngrams(mention))
        if mention.sentence.is_tabular():
            # Tabular mentions must not yield any None ngram.
            assert all(x is not None for x in row_ngrams)
            assert all(x is not None for x in col_ngrams)
        else:
            # Non-tabular mentions yield exactly one ngram, which is None.
            assert len(row_ngrams) == 1 and row_ngrams[0] is None
            assert len(col_ngrams) == 1 and col_ngrams[0] is None
        return "birth_place" in row_ngrams

    matcher = LambdaFunctionMatcher(func=get_row_and_column_ngrams)
    extractor_udf = MentionExtractorUDF([Place], [place_space], [matcher])

    doc = extractor_udf.apply(doc)
示例#10
0
def test_ngrams():
    """Check that ``n_min``/``n_max`` bound the size of extracted mentions."""
    file_name = "lincoln_short"
    doc = parse_doc(f"tests/data/pure_html/{file_name}.html", file_name)

    # Mention extraction with up to three words per mention
    Person = mention_subclass("Person")
    up_to_trigrams = MentionNgrams(n_max=3)
    person_matcher = PersonMatcher()

    extractor_udf = MentionExtractorUDF(
        [Person], [up_to_trigrams], [person_matcher]
    )
    doc = extractor_udf.apply(doc)

    assert len(doc.persons) == 118
    unigrams = sum(1 for p in doc.persons if p.context.get_num_words() == 1)
    too_long = sum(1 for p in doc.persons if p.context.get_num_words() > 3)
    assert unigrams == 49
    assert too_long == 0

    # Drop the mentions from the first pass before extracting again.
    for stale in list(doc.persons):
        doc.persons.remove(stale)
    assert len(doc.persons) == 0

    # Unigram exclusion: at least two and at most three words per mention
    bi_and_trigrams = MentionNgrams(n_min=2, n_max=3)
    extractor_udf = MentionExtractorUDF(
        [Person], [bi_and_trigrams], [person_matcher]
    )
    doc = extractor_udf.apply(doc)

    assert len(doc.persons) == 69
    unigrams = sum(1 for p in doc.persons if p.context.get_num_words() == 1)
    too_long = sum(1 for p in doc.persons if p.context.get_num_words() > 3)
    assert unigrams == 0
    assert too_long == 0
示例#11
0
def test_visualizer():
    """Unit test of visualizer using the md document.

    Parses the md document, extracts ``Org`` mentions and unary
    ``Organization`` candidates from it, and asks the visualizer to display
    the first candidate.
    """
    # Imported inside the test so the visualizer module is only loaded here.
    from fonduer.utils.visualizer import Visualizer  # noqa

    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/md.pdf"

    # Grab the md document
    doc = parse_doc(docs_path, "md", pdf_path)
    assert doc.name == "md"

    # Mention extraction
    organization_ngrams = MentionNgrams(n_max=1)

    Org = mention_subclass("Org")

    organization_matcher = OrganizationMatcher()

    mention_extractor_udf = MentionExtractorUDF([Org], [organization_ngrams],
                                                [organization_matcher])

    doc = mention_extractor_udf.apply(doc)

    # Candidate extraction
    Organization = candidate_subclass("Organization", [Org])

    candidate_extractor_udf = CandidateExtractorUDF([Organization], None,
                                                    False, False, True)

    doc = candidate_extractor_udf.apply(doc, split=0)

    cands = doc.organizations

    # Test visualizer; it is given the PDF directory rather than the file.
    pdf_path = "tests/data/pdf_simple"
    vis = Visualizer(pdf_path)
    vis.display_candidates([cands[0]])
示例#12
0
def test_save_subclasses():
    """Test if subclasses can be saved.

    The pickle files are written to a temporary directory, so a file left
    behind by an earlier run cannot satisfy the existence checks and the
    working directory is not polluted.
    """
    import tempfile  # local import: only this test needs it

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep a trailing separator, like the "./" style of directory path.
        out_dir = os.path.join(tmp_dir, "")

        mention_class = mention_subclass("test_mention_class")
        _save_mention_classes([mention_class], out_dir)
        assert os.path.exists(os.path.join(out_dir, "mention_classes.pkl"))

        candidate_class = candidate_subclass(
            "test_candidate_class", [mention_class]
        )
        _save_candidate_classes([candidate_class], out_dir)
        assert os.path.exists(os.path.join(out_dir, "candidate_classes.pkl"))
示例#13
0
def test_pickle_subclasses():
    """Round-trip mention/candidate subclasses and their instances via pickle."""
    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    subclasses = (Part, Temp, PartTemp)

    # The classes themselves must survive a dumps/loads round trip.
    logger.info("Test if mention/candidate subclasses are picklable")
    for subclass in subclasses:
        pickle.loads(pickle.dumps(subclass))

    # So must instances; all are built first, then pickled one by one.
    logger.info("Test if their objects are pickable")
    instances = [subclass() for subclass in subclasses]
    for instance in instances:
        pickle.loads(pickle.dumps(instance))
示例#14
0
def test_visualizer(caplog):
    """Unit test of visualizer using the md document.

    Parses the md document into the database, extracts ``Org`` mentions and
    unary ``Organization`` candidates from it, and asks the visualizer to
    display the first candidate of split 0.
    """
    # Imported inside the test so the visualizer module is only loaded here.
    from fonduer.utils.visualizer import Visualizer  # noqa

    caplog.set_level(logging.INFO)
    session = Meta.init("postgresql://localhost:5432/" + ATTRIBUTE).Session()

    PARALLEL = 1
    max_docs = 1
    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/md.pdf"

    # Preprocessor for the Docs
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    # Create a Parser and parse the md document
    corpus_parser = Parser(session,
                           structural=True,
                           lingual=True,
                           visual=True,
                           pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    # Grab the md document
    doc = session.query(Document).order_by(Document.name).all()[0]
    assert doc.name == "md"

    # Mention extraction
    organization_ngrams = MentionNgrams(n_max=1)

    Org = mention_subclass("Org")

    organization_matcher = OrganizationMatcher()

    mention_extractor = MentionExtractor(session, [Org], [organization_ngrams],
                                         [organization_matcher])

    mention_extractor.apply([doc], parallelism=PARALLEL)

    # Candidate extraction
    Organization = candidate_subclass("Organization", [Org])

    candidate_extractor = CandidateExtractor(session, [Organization])

    candidate_extractor.apply([doc], split=0, parallelism=PARALLEL)

    cands = session.query(Organization).filter(Organization.split == 0).all()

    # Test visualizer; it is given the PDF directory rather than the file.
    pdf_path = "tests/data/pdf_simple"
    vis = Visualizer(pdf_path)
    vis.display_candidates([cands[0]])
示例#15
0
def test_multimodal_cand():
    """Check mention counts for every multimodal mention space.

    Mentions are extracted from one parsed document with a
    ``DoNothingMatcher`` for each context type, and the number of mentions
    attached to the document per subclass is compared with the expected value.
    """
    file_name = "radiology"
    doc = parse_doc(f"tests/data/pure_html/{file_name}.html", file_name)

    assert len(doc.sentences) == 35

    # Mention subclasses, keyed by name.
    subclass_names = (
        "m_doc",
        "m_sec",
        "m_tab",
        "m_fig",
        "m_cell",
        "m_para",
        "m_cap",
        "m_sent",
    )
    subclasses = {name: mention_subclass(name) for name in subclass_names}

    # One mention space per subclass, keyed by the same names.
    spaces = {
        "m_doc": MentionDocuments(),
        "m_sec": MentionSections(),
        "m_tab": MentionTables(),
        "m_fig": MentionFigures(),
        "m_cell": MentionCells(),
        "m_para": MentionParagraphs(),
        "m_cap": MentionCaptions(),
        "m_sent": MentionSentences(),
    }

    # Order in which classes and spaces are handed to the extractor.
    extraction_order = (
        "m_doc",
        "m_cap",
        "m_sec",
        "m_tab",
        "m_fig",
        "m_para",
        "m_sent",
        "m_cell",
    )
    ms = [subclasses[name] for name in extraction_order]
    m = [spaces[name] for name in extraction_order]
    matchers = [DoNothingMatcher()] * 8

    extractor_udf = MentionExtractorUDF(ms, m, matchers)
    doc = extractor_udf.apply(doc)

    # Expected number of mentions per document attribute.
    expected_counts = {
        "m_docs": 1,
        "m_caps": 2,
        "m_secs": 5,
        "m_tabs": 2,
        "m_figs": 2,
        "m_paras": 30,
        "m_sents": 35,
        "m_cells": 21,
    }
    for attribute, expected in expected_counts.items():
        assert len(getattr(doc, attribute)) == expected
示例#16
0
def test_visualizer():
    """Check bounding boxes and candidate display on the md document."""
    from fonduer.utils.visualizer import Visualizer, get_box  # noqa

    # Parse the md document together with its PDF directory.
    doc = parse_doc("tests/data/html_simple/md.html", "md", "tests/data/pdf_simple/")
    assert doc.name == "md"

    # Mention extraction
    org_space = MentionNgrams(n_max=1)
    Org = mention_subclass("Org")
    org_matcher = OrganizationMatcher()
    mention_udf = MentionExtractorUDF([Org], [org_space], [org_matcher])
    doc = mention_udf.apply(doc)

    # Candidate extraction
    Organization = candidate_subclass("Organization", [Org])
    candidate_udf = CandidateExtractorUDF(
        [Organization], None, False, False, True
    )
    doc = candidate_udf.apply(doc, split=0)

    # A single candidate is enough for the checks below.
    cand = doc.organizations[0]

    vis = Visualizer("tests/data/pdf_simple")

    # Every box must be well-formed and equal to the context's own bbox.
    boxes = [get_box(mention.context) for mention in cand.get_mentions()]
    assert all(b.top <= b.bottom and b.left <= b.right for b in boxes)
    context_boxes = [mention.context.get_bbox() for mention in cand.get_mentions()]
    assert boxes == context_boxes

    # Display the candidate through the visualizer.
    vis.display_candidates([cand])
示例#17
0
def test_ngrams(caplog):
    """Check that ``n_min``/``n_max`` bound the size of extracted mentions."""
    caplog.set_level(logging.INFO)

    PARALLEL = 4
    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor(
        "tests/data/pure_html/lincoln_short.html", max_docs=max_docs
    )
    parser = Parser(session, structural=True, lingual=True)
    parser.apply(preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 503
    docs = session.query(Document).order_by(Document.name).all()

    # Mention extraction with up to three words per mention
    Person = mention_subclass("Person")
    up_to_trigrams = MentionNgrams(n_max=3)
    person_matcher = PersonMatcher()

    extractor = MentionExtractor(
        session, [Person], [up_to_trigrams], [person_matcher]
    )
    extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Person).count() == 118
    persons = session.query(Person).all()
    assert sum(1 for p in persons if p.context.get_num_words() == 1) == 49
    assert sum(1 for p in persons if p.context.get_num_words() > 3) == 0

    # Unigram exclusion: at least two and at most three words per mention
    bi_and_trigrams = MentionNgrams(n_min=2, n_max=3)
    extractor = MentionExtractor(
        session, [Person], [bi_and_trigrams], [person_matcher]
    )
    extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Person).count() == 69
    persons = session.query(Person).all()
    assert sum(1 for p in persons if p.context.get_num_words() == 1) == 0
    assert sum(1 for p in persons if p.context.get_num_words() > 3) == 0
示例#18
0
def test_row_col_ngram_extraction(caplog):
    """Check row/column ngrams for mentions inside and outside of tables.

    The checks live in the matcher function, so they run for every mention
    the extractor offers to the matcher.
    """
    caplog.set_level(logging.INFO)
    PARALLEL = 1
    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    # Parse the document into the database.
    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor(
        "tests/data/pure_html/lincoln_short.html", max_docs=max_docs
    )
    parser = Parser(session, structural=True, lingual=True)
    parser.apply(preprocessor, parallelism=PARALLEL)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention space and mention class
    place_space = MentionNgramsTemp(n_max=4)
    Place = mention_subclass("Place")

    def get_row_and_column_ngrams(mention):
        """Assert on the mention's row/column ngrams; match "birth_place" rows."""
        row_ngrams = list(get_row_ngrams(mention))
        col_ngrams = list(get_col_ngrams(mention))
        if mention.sentence.is_tabular():
            # Tabular mentions must not yield any None ngram.
            assert all(x is not None for x in row_ngrams)
            assert all(x is not None for x in col_ngrams)
        else:
            # Non-tabular mentions yield exactly one ngram, which is None.
            assert len(row_ngrams) == 1 and row_ngrams[0] is None
            assert len(col_ngrams) == 1 and col_ngrams[0] is None
        return "birth_place" in row_ngrams

    matcher = LambdaFunctionMatcher(func=get_row_and_column_ngrams)
    extractor = MentionExtractor(session, [Place], [place_space], [matcher])

    extractor.apply(docs, parallelism=PARALLEL)
示例#19
0
def test_multinary_relation_feature_extraction():
    """Check feature-key counts for a ternary Part/Temp/Volt candidate.

    Mentions and candidates are extracted from one document, then the
    candidates are featurized with the default library, with one custom
    extractor added, with a failing extractor, and with each single modality.
    The numbers of distinct feature keys are compared at the end.
    """
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    # Parsing
    doc = parse_doc(docs_path, "112823", pdf_path)
    assert len(doc.sentences) == 799

    # Mention extraction: one unigram space per mention class
    part_space = MentionNgrams(n_max=1)
    temp_space = MentionNgrams(n_max=1)
    volt_space = MentionNgrams(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_udf = MentionExtractorUDF(
        [Part, Temp, Volt],
        [part_space, temp_space, volt_space],
        [part_matcher, temp_matcher, volt_matcher],
    )
    doc = mention_udf.apply(doc)

    assert len(doc.parts) == 62
    assert len(doc.temps) == 16
    assert len(doc.volts) == 33
    part, temp, volt = doc.parts[0], doc.temps[0], doc.volts[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")
    logger.info(f"Volt: {volt.context}")

    # Candidate extraction
    PartTempVolt = candidate_subclass("PartTempVolt", [Part, Temp, Volt])
    candidate_udf = CandidateExtractorUDF(
        [PartTempVolt], None, False, False, True
    )
    doc = candidate_udf.apply(doc, split=0)

    # Manually set id as it is not set automatically b/c a database is not used.
    for position, cand in enumerate(doc.part_temp_volts):
        cand.id = position

    n_cands = len(doc.part_temp_volts)

    def count_feature_keys(udf):
        """Return how many distinct feature keys ``udf`` produces for ``doc``."""
        features = itertools.chain.from_iterable(udf.apply(doc))
        return len({key for feature in features for key in feature["keys"]})

    # Default feature library
    default_udf = FeaturizerUDF([PartTempVolt], FeatureExtractor())
    n_default_feats = count_feature_keys(default_udf)

    # Custom extractor that emits one feature per candidate id
    def feat_ext(candidates):
        if not isinstance(candidates, list):
            candidates = [candidates]
        for candidate in candidates:
            yield candidate.id, f"cand_id_{candidate.id}", 1

    # Default feature library plus the custom extractor
    custom_extractors = FeatureExtractor(customize_feature_funcs=[feat_ext])
    custom_udf = FeaturizerUDF(
        [PartTempVolt], feature_extractors=custom_extractors
    )
    n_default_w_customized_features = count_feature_keys(custom_udf)

    # Extractor that always fails
    def bad_feat_ext(candidates):
        raise RuntimeError()

    # Its error must reach the caller of apply().
    failing_extractors = FeatureExtractor(
        customize_feature_funcs=[bad_feat_ext]
    )
    failing_udf = FeaturizerUDF(
        [PartTempVolt], feature_extractors=failing_extractors
    )
    logger.info("Featurizing with a spurious feature extractor...")
    with pytest.raises(RuntimeError):
        failing_udf.apply(doc)

    # One featurization per single modality, in this fixed order
    keys_per_modality = {}
    for modality in ("textual", "tabular", "structural", "visual"):
        modality_extractors = FeatureExtractor(features=[modality])
        modality_udf = FeaturizerUDF(
            [PartTempVolt], feature_extractors=modality_extractors
        )
        keys_per_modality[modality] = count_feature_keys(modality_udf)

    # The default key count equals the sum of the four per-modality counts.
    assert n_default_feats == sum(keys_per_modality.values())

    # The custom extractor adds exactly one key per candidate.
    assert n_default_w_customized_features == n_default_feats + n_cands
示例#20
0
def test_unary_relation_feature_extraction():
    """Check feature-key counts for a unary Part candidate.

    Mentions and candidates are extracted from one document, then the
    candidates are featurized with the default library and with each single
    modality. The default key count must equal the sum of the four others.
    """
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    # Parsing
    doc = parse_doc(docs_path, "112823", pdf_path)
    assert len(doc.sentences) == 799

    # Mention extraction
    part_space = MentionNgrams(n_max=1)
    Part = mention_subclass("Part")
    mention_udf = MentionExtractorUDF([Part], [part_space], [part_matcher])
    doc = mention_udf.apply(doc)

    assert doc.name == "112823"
    assert len(doc.parts) == 62
    part = doc.parts[0]
    logger.info(f"Part: {part.context}")

    # Candidate extraction
    PartRel = candidate_subclass("PartRel", [Part])
    candidate_udf = CandidateExtractorUDF([PartRel], None, False, False, True)
    doc = candidate_udf.apply(doc, split=0)

    def count_feature_keys(udf):
        """Return how many distinct feature keys ``udf`` produces for ``doc``."""
        features = itertools.chain.from_iterable(udf.apply(doc))
        return len({key for feature in features for key in feature["keys"]})

    # Default feature library
    default_udf = FeaturizerUDF([PartRel], FeatureExtractor())
    n_default_feats = count_feature_keys(default_udf)

    # One featurization per single modality, in this fixed order
    keys_per_modality = {}
    for modality in ("textual", "tabular", "structural", "visual"):
        modality_extractors = FeatureExtractor(features=[modality])
        modality_udf = FeaturizerUDF(
            [PartRel], feature_extractors=modality_extractors
        )
        keys_per_modality[modality] = count_feature_keys(modality_udf)

    # The default key count equals the sum of the four per-modality counts.
    assert n_default_feats == sum(keys_per_modality.values())
示例#21
0
def _entity_level_f1_score(
    disc_model, test_cands, F_test, gold_file, test_docs, parts_by_doc
):
    """Predict on the first relation's test candidates and return entity-level F1.

    The trained ``disc_model`` predicts over ``(test_cands[0], F_test[0])`` with
    ``b=0.6`` and ``pos_label=TRUE``; the candidates predicted ``TRUE`` are
    scored against ``gold_file`` by ``entity_level_f1``. Precision, recall and
    F1 are logged at INFO level.

    :param disc_model: A trained model exposing ``predict``.
    :param test_cands: List of candidate lists (one per relation); only index 0
        is used.
    :param F_test: List of feature matrices (one per relation); only index 0
        is used.
    :param gold_file: Path to the gold CSV handed to ``entity_level_f1``.
    :param test_docs: Documents of the test split.
    :param parts_by_doc: Mapping forwarded to ``entity_level_f1``.
    :return: The F1 score, or ``nan`` when it is undefined.
    """
    test_score = disc_model.predict((test_cands[0], F_test[0]), b=0.6, pos_label=TRUE)
    true_pred = [test_cands[0][_] for _ in np.nditer(np.where(test_score == TRUE))]

    (TP, FP, FN) = entity_level_f1(
        true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc
    )

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    # Every ratio is guarded so an empty denominator yields NaN rather than
    # raising ZeroDivisionError.
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    return f1


def test_e2e():
    """Run an end-to-end test on documents of the hardware domain.

    The test parses the corpus, extracts mentions and candidates, featurizes
    and labels them, trains a label model plus several discriminative models,
    and checks the resulting counts and F1 scores against fixed expectations.
    """
    PARALLEL = 4

    max_docs = 12

    fonduer.init_logging(
        log_dir="log_folder",
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = fonduer.Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Expected number of sentences, tables, figures and captions for each of
    # the sorted documents. Documents are indexed explicitly so that a short
    # document list fails loudly instead of being silently truncated.
    expected_counts = (
        ("sentences", (799, 663, 784, 661, 513, 700, 528, 161, 228, 511, 331, 528)),
        ("tables", (9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9)),
        ("figures", (32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27)),
        ("captions", (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)),
    )
    for attr, counts in expected_counts:
        for idx, expected in enumerate(counts):
            actual = len(getattr(docs[idx], attr))
            assert actual == expected, f"docs[{idx}].{attr}: {actual} != {expected}"

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    # Assign documents to splits in order of their name.
    for i, doc in enumerate(sorted(docs, key=lambda d: d.name)):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 138
    assert session.query(Volt).count() == 140
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299
    assert (
        len(
            mention_extractor.get_mentions(
                docs=[session.query(Document).filter(Document.name == "112823").first()]
            )[0]
        )
        == 70
    )

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    # The loop variable is deliberately not called ``docs`` so the parsed
    # document list above is not shadowed.
    for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3493
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 416
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3493
    assert (
        len(
            candidate_extractor.get_candidates(
                docs=[session.query(Document).filter(Document.name == "112823").first()]
            )[0]
        )
        == 1432
    )

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    assert session.query(FeatureKey).count() == 1260

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NNP NN IN]"])
    assert session.query(FeatureKey).count() == 1259

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(
        session.query(FeatureKey)
        .filter(FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]")
        .one()
        .candidate_classes
    ) == {"part_temp", "part_volt"}
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]"
    ).one().candidate_classes == ["part_temp"]
    assert session.query(FeatureKey).count() == 1259

    # Inserting the removed key
    featurizer.upsert_keys(
        ["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp, PartVolt]
    )
    assert set(
        session.query(FeatureKey)
        .filter(FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]")
        .one()
        .candidate_classes
    ) == {"part_temp", "part_volt"}
    assert session.query(FeatureKey).count() == 1259
    # Removing the key again
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1258
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6478
    assert session.query(FeatureKey).count() == 4538
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3493, 4538)
    assert F_train[1].shape == (2985, 4538)
    assert len(featurizer.get_keys()) == 4538

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6692
    assert session.query(FeatureKey).count() == 4538
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (61, 4538)
    assert F_dev[1].shape == (153, 4538)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8252
    assert session.query(FeatureKey).count() == 4538
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (416, 4538)
    assert F_test[1].shape == (1144, 4538)

    gold_file = "tests/data/hardware_tutorial_gold.csv"

    labeler = Labeler(session, [PartTemp, PartVolt])

    labeler.apply(
        docs=last_docs,
        lfs=[[gold], [gold]],
        table=GoldLabel,
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(GoldLabel).count() == 8252

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]

    # A flat list of LFs (instead of one list per relation) must be rejected.
    with pytest.raises(ValueError):
        labeler.apply(split=0, lfs=stg_temp_lfs, train=True, parallelism=PARALLEL)

    labeler.apply(
        docs=train_docs,
        lfs=[stg_temp_lfs, ce_v_max_lfs],
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    # Test Dropping LabelerKey
    labeler.drop_keys(["LF_storage_row"])
    assert len(labeler.get_keys()) == 8

    # Test Upserting LabelerKey
    labeler.upsert_keys(["LF_storage_row"])
    assert "LF_storage_row" in [label.name for label in labeler.get_keys()]

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3493, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3493, 1)

    gen_model = LabelModel()
    gen_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = gen_model.predict_proba(L_train[0])

    disc_model = LogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]),
        train_marginals,
        X_dev=(train_cands[0], F_train[0]),
        Y_dev=L_train_gold[0].reshape(-1),
        b=0.6,
        pos_label=TRUE,
        n_epochs=5,
        lr=0.001,
    )

    # NOTE(review): this pickle is a repository test fixture; pickle.load must
    # never be pointed at untrusted data.
    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    # With the first, smaller LF set the score is expected to be mediocre.
    f1 = _entity_level_f1_score(
        disc_model, test_cands, F_test, gold_file, test_docs, parts_by_doc
    )
    assert 0.3 < f1 < 0.7

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0, lfs=[stg_temp_lfs_2, ce_v_max_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 16)

    gen_model = LabelModel()
    gen_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = gen_model.predict_proba(L_train[0])

    # Train each discriminative model on the refreshed marginals and require
    # the same entity-level F1 threshold from all of them. ``disc_model`` is
    # left bound to the last model (SparseLSTM) for the mention-level check.
    for model_class in (
        LogisticRegression,
        LSTM,
        SparseLogisticRegression,
        SparseLSTM,
    ):
        disc_model = model_class()
        disc_model.train(
            (train_cands[0], F_train[0]), train_marginals, n_epochs=5, lr=0.001
        )
        f1 = _entity_level_f1_score(
            disc_model, test_cands, F_test, gold_file, test_docs, parts_by_doc
        )
        assert f1 > 0.7, f"{model_class.__name__}: f1={f1}"

    # Evaluate mention level scores
    L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")
    Y_test = L_test_gold[0].reshape(-1)

    scores = disc_model.score((test_cands[0], F_test[0]), Y_test, b=0.6, pos_label=TRUE)

    logger.info(scores)

    assert scores["f1"] > 0.6
示例#22
0
def test_mention_longest_match(caplog):
    """Check that ``longest_match_only`` filters sub-spans of Place mentions.

    The same document is run through mention extraction twice, once with the
    birthplace matcher keeping every matching span and once keeping only the
    longest ones, and the resulting Place spans are compared.
    """
    caplog.set_level(logging.INFO)
    # SpaCy on mac has issue on parallel parsing
    n_procs = 1

    doc_limit = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    html_path = "tests/data/pure_html/lincoln_short.html"

    # Parsing
    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor(html_path, max_docs=doc_limit)
    parser = Parser(session, structural=True, lingual=True)
    parser.apply(preprocessor, parallelism=n_procs)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    name_space = MentionNgramsPart(n_max=3)
    place_space = MentionNgramsTemp(n_max=4)

    Name = mention_subclass("Name")
    Place = mention_subclass("Place")

    def is_birthplace_table_row(mention):
        """Return True for tabular mentions whose row contains "birth_place"."""
        if not mention.sentence.is_tabular():
            return False
        return "birth_place" in get_row_ngrams(mention, lower=True)

    def extract_place_spans(longest_match_only):
        """Run mention extraction and return the span text of all Places."""
        birthplace_matcher = LambdaFunctionMatcher(
            func=is_birthplace_table_row, longest_match_only=longest_match_only
        )
        extractor = MentionExtractor(
            session,
            [Name, Place],
            [name_space, place_space],
            [PersonMatcher(), birthplace_matcher],
        )
        extractor.apply(docs, parallelism=n_procs)
        return [place.context.get_span() for place in session.query(Place).all()]

    # Without longest-match filtering, sub-spans such as "Farm" are kept.
    all_spans = extract_place_spans(longest_match_only=False)
    assert "Sinking Spring Farm" in all_spans
    assert "Farm" in all_spans
    assert len(all_spans) == 23

    # With longest-match filtering, only the full span survives.
    longest_spans = extract_place_spans(longest_match_only=True)
    assert "Sinking Spring Farm" in longest_spans
    assert "Farm" not in longest_spans
    assert len(longest_spans) == 4
示例#23
0
def test_cand_gen(caplog):
    """Test extracting candidates from mentions from documents.

    Covers argument validation of ``MentionExtractor`` and
    ``CandidateExtractor`` (mismatched list lengths must raise ``ValueError``),
    candidate counts with no, partial and full throttling, and the deletion
    relationship between candidates and their mentions.
    """
    caplog.set_level(logging.INFO)

    if platform == "darwin":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        logger.info("Using two cores.")
        PARALLEL = 2  # Travis only gives 2 cores

    def do_nothing_matcher(fig):
        """Accept every figure."""
        return True

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    with pytest.raises(ValueError):
        MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    with pytest.raises(ValueError):
        # The mention spaces are well-formed here so that the only problem
        # with this call is the short matcher list.
        MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 70
    assert session.query(Volt).count() == 33
    assert session.query(Temp).count() == 23
    assert session.query(Fig).count() == 31
    part = session.query(Part).order_by(Part.id).all()[0]
    volt = session.query(Volt).order_by(Volt.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Volt: {volt.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    with pytest.raises(ValueError):
        CandidateExtractor(
            session,
            [PartTemp, PartVolt],
            throttlers=[
                temp_throttler,
                volt_throttler,
                volt_throttler,
            ],  # Fail, mismatched arity
        )

    with pytest.raises(ValueError):
        CandidateExtractor(
            session,
            [PartTemp],  # Fail, mismatched arity
            throttlers=[temp_throttler, volt_throttler],
        )

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt]
    )  # Pass, no throttler

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1610
    assert session.query(PartVolt).count() == 2310
    assert session.query(Candidate).count() == 3920
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(PartVolt).count() == 0

    # Test with None in throttlers in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, None]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    assert session.query(PartTemp).count() == 1432
    assert session.query(PartVolt).count() == 2310
    assert session.query(Candidate).count() == 3742
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0

    # Test with a throttler for every candidate class
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1432
    assert session.query(PartVolt).count() == 1993
    assert session.query(Candidate).count() == 3425
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].volts) == 33
    assert len(docs[0].temps) == 23

    # Test that deletion of a Candidate does not delete the Mention
    session.query(PartTemp).delete(synchronize_session="fetch")
    assert session.query(PartTemp).count() == 0
    assert session.query(Temp).count() == 23
    assert session.query(Part).count() == 70

    # Test deletion of Candidate if Mention is deleted
    assert session.query(PartVolt).count() == 1993
    assert session.query(Volt).count() == 33
    session.query(Volt).delete(synchronize_session="fetch")
    assert session.query(Volt).count() == 0
    assert session.query(PartVolt).count() == 0
示例#24
0
def test_cand_gen_cascading_delete(caplog):
    """Verify that deleting candidates and mentions cascades as expected.

    Deleting a row through the parent ``Candidate`` class must remove the
    matching ``PartTemp`` row, and clearing mentions must remove every
    candidate.
    """
    caplog.set_level(logging.INFO)

    on_mac = platform == "darwin"
    logger.info("Using single core." if on_mac else "Using two cores.")
    n_procs = 1 if on_mac else 2  # Travis only gives 2 cores

    doc_limit = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    html_dir = "tests/data/html/"
    pdf_dir = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor(html_dir, max_docs=doc_limit)
    parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_dir
    )
    parser.apply(preprocessor, parallelism=n_procs)
    assert session.query(Document).count() == doc_limit
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_space = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_space = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp],
        [part_space, temp_space],
        [part_matcher, temp_matcher],
    )
    mention_extractor.clear_all()
    mention_extractor.apply(docs, parallelism=n_procs)

    for model, expected in ((Mention, 93), (Part, 70), (Temp, 23)):
        assert session.query(model).count() == expected
    first_part = session.query(Part).order_by(Part.id).all()[0]
    first_temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {first_part.context}")
    logger.info(f"Temp: {first_temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )
    candidate_extractor.apply(docs, split=0, parallelism=n_procs)

    for model in (PartTemp, Candidate):
        assert session.query(model).count() == 1432
    only_doc = docs[0]
    assert only_doc.name == "112823"
    assert len(only_doc.parts) == 70
    assert len(only_doc.temps) == 23

    # Delete from parent class should cascade to child
    victim = session.query(Candidate).first()
    session.query(Candidate).filter_by(id=victim.id).delete(
        synchronize_session="fetch"
    )
    for model in (Candidate, PartTemp):
        assert session.query(model).count() == 1431

    # Clearing Mentions should also delete Candidates
    mention_extractor.clear()
    for model in (Mention, Part, Temp, PartTemp, Candidate):
        assert session.query(model).count() == 0
示例#25
0
文件: test_e2e.py 项目: SenWu/fonduer
def test_e2e():
    """Run an end-to-end test on documents of the hardware domain.

    The test walks through the whole pipeline in order and asserts on the
    database state after each stage:

    1. Parse ``max_docs`` documents and check the sentence, table, figure and
       caption counts of every document.
    2. Split the documents into train/dev/test sets by sorted document name.
    3. Extract Part/Temp/Volt mentions and PartTemp/PartVolt candidates.
    4. Featurize the candidates and exercise dropping/upserting feature keys.
    5. Apply gold labels and labeling functions, including dropping/upserting
       label keys and ``Labeler.update``.
    6. Fit a ``LabelModel`` and train Emmental models (logistic regression
       twice, then LSTM), asserting on the entity-level F1 of each run.

    The hard-coded counts below depend on the exact order of the preceding
    calls, so statements should not be reordered.
    """
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 12

    fonduer.init_logging(
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = fonduer.Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    # Parse the corpus with structural, lingual and visual information enabled.
    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)
    assert session.query(Document).count() == max_docs

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    # Both accessors are expected to return the same set of documents here.
    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Check sentence numbers
    assert len(docs[0].sentences) == 799
    assert len(docs[1].sentences) == 663
    assert len(docs[2].sentences) == 784
    assert len(docs[3].sentences) == 661
    assert len(docs[4].sentences) == 513
    assert len(docs[5].sentences) == 700
    assert len(docs[6].sentences) == 528
    assert len(docs[7].sentences) == 161
    assert len(docs[8].sentences) == 228
    assert len(docs[9].sentences) == 511
    assert len(docs[10].sentences) == 331
    assert len(docs[11].sentences) == 528

    # Check table numbers
    assert len(docs[0].tables) == 9
    assert len(docs[1].tables) == 9
    assert len(docs[2].tables) == 14
    assert len(docs[3].tables) == 11
    assert len(docs[4].tables) == 11
    assert len(docs[5].tables) == 10
    assert len(docs[6].tables) == 10
    assert len(docs[7].tables) == 2
    assert len(docs[8].tables) == 7
    assert len(docs[9].tables) == 10
    assert len(docs[10].tables) == 6
    assert len(docs[11].tables) == 9

    # Check figure numbers
    assert len(docs[0].figures) == 32
    assert len(docs[1].figures) == 11
    assert len(docs[2].figures) == 38
    assert len(docs[3].figures) == 31
    assert len(docs[4].figures) == 7
    assert len(docs[5].figures) == 38
    assert len(docs[6].figures) == 10
    assert len(docs[7].figures) == 31
    assert len(docs[8].figures) == 4
    assert len(docs[9].figures) == 27
    assert len(docs[10].figures) == 5
    assert len(docs[11].figures) == 27

    # Check caption numbers
    assert len(docs[0].captions) == 0
    assert len(docs[1].captions) == 0
    assert len(docs[2].captions) == 0
    assert len(docs[3].captions) == 0
    assert len(docs[4].captions) == 0
    assert len(docs[5].captions) == 0
    assert len(docs[6].captions) == 0
    assert len(docs[7].captions) == 0
    assert len(docs[8].captions) == 0
    assert len(docs[9].captions) == 0
    assert len(docs[10].captions) == 0
    assert len(docs[11].captions) == 0

    # Assign documents, ordered by name, to the splits: the first 50% go to
    # train, the next 25% to dev and the remainder to test.
    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (doc_name, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    # The three lists are parallel: class, mention space and matcher per type.
    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 138
    assert session.query(Volt).count() == 140
    # One list of mentions per mention class; index 0 holds the Part mentions.
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299
    # Restricting to a single document returns only that document's mentions.
    assert (len(
        mention_extractor.get_mentions(docs=[
            session.query(Document).filter(Document.name == "112823").first()
        ])[0]) == 70)

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt],
        throttlers=[temp_throttler, volt_throttler])

    # Split index: 0 = train, 1 = dev, 2 = test.
    # NOTE(review): this loop rebinds ``docs`` (previously the sorted list of
    # all documents) to each split's set. ``docs`` is not read again below, so
    # it is harmless today, but a distinct loop variable would be safer.
    for i, docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3493
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 416
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    # One list per candidate class; index 0 holds the PartTemp candidates.
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3493
    assert (len(
        candidate_extractor.get_candidates(docs=[
            session.query(Document).filter(Document.name == "112823").first()
        ])[0]) == 1432)

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    assert session.query(FeatureKey).count() == 1260

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NNP NN IN]"])
    assert session.query(FeatureKey).count() == 1259

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name ==
        "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes == ["part_temp"]
    # The row itself survives because part_temp still references the key.
    assert session.query(FeatureKey).count() == 1259

    # Inserting the removed key
    featurizer.upsert_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                           candidate_classes=[PartTemp, PartVolt])
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    assert session.query(FeatureKey).count() == 1259
    # Removing the key again
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1258
    # Wipe all features and keys so the real featurization starts clean.
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    # Featurize the train split with train=True; the dev and test splits
    # below are applied without it and the FeatureKey count stays at 4538.
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6478
    assert session.query(FeatureKey).count() == 4538
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3493, 4538)
    assert F_train[1].shape == (2985, 4538)
    assert len(featurizer.get_keys()) == 4538

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6692
    assert session.query(FeatureKey).count() == 4538
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (61, 4538)
    assert F_dev[1].shape == (153, 4538)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8252
    assert session.query(FeatureKey).count() == 4538
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (416, 4538)
    assert F_test[1].shape == (1144, 4538)

    gold_file = "tests/data/hardware_tutorial_gold.csv"

    labeler = Labeler(session, [PartTemp, PartVolt])

    # Load gold labels for every candidate into the GoldLabel table; lfs holds
    # one list of labeling functions per candidate class.
    labeler.apply(
        docs=last_docs,
        lfs=[[gold], [gold]],
        table=GoldLabel,
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(GoldLabel).count() == 8252

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]

    # lfs is a flat list here instead of one list per candidate class; the
    # labeler is expected to reject it with ValueError (presumably because the
    # labeler was built with two candidate classes -- confirm against Labeler).
    with pytest.raises(ValueError):
        labeler.apply(split=0,
                      lfs=stg_temp_lfs,
                      train=True,
                      parallelism=PARALLEL)

    labeler.apply(
        docs=train_docs,
        lfs=[stg_temp_lfs, ce_v_max_lfs],
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(Label).count() == 6478
    # 6 temperature LFs + 3 voltage LFs = 9 label keys.
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    # Test Dropping LabelerKey
    labeler.drop_keys(["LF_storage_row"])
    assert len(labeler.get_keys()) == 8

    # Test Upserting LabelerKey
    labeler.upsert_keys(["LF_storage_row"])
    assert "LF_storage_row" in [label.name for label in labeler.get_keys()]

    # Gold labels can be fetched with or without naming the annotator.
    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3493, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3493, 1)

    # From here on only the first relation (index 0, PartTemp) is modelled.
    label_model = LabelModel()
    label_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    # Collect word counter
    word_counter = collect_word_counter(train_cands)

    emmental.init(fonduer.Meta.log_path)

    # Training config
    config = {
        "meta_config": {
            "verbose": False
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.0
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    f"{ATTRIBUTE}/{ATTRIBUTE}/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }
    emmental.Meta.update_config(config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)

    # Keep only candidates whose marginals are not (numerically) uniform,
    # i.e. whose largest and smallest class probability differ by > 1e-6.
    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LogisticRegression")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    # NOTE(review): the literal 2 sits where the train loader passes its
    # marginals -- confirm its meaning against FonduerDataset's signature.
    test_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(ATTRIBUTE, test_cands[0], F_test[0],
                               emb_layer.word2id, 2),
        split="test",
        batch_size=100,
        shuffle=False,
    )

    # Treat a test candidate as predicted-true when the probability in the
    # TRUE column exceeds 0.6.
    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.6)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    # NOTE(review): pickle.load is only safe because this is a fixture file
    # shipped with the tests; never point it at untrusted data.
    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    # Precision/recall/F1 fall back to NaN when their denominator is zero;
    # a NaN F1 fails the comparison below.
    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    # With the first set of labeling functions the F1 must land in (0.3, 0.7).
    assert f1 < 0.7 and f1 > 0.3

    # Second round: replace the temperature LFs with a larger set and update
    # the existing labels in place.
    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0,
                   lfs=[stg_temp_lfs_2, ce_v_max_lfs],
                   parallelism=PARALLEL)
    # The number of Label rows is unchanged while the number of keys grows.
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 16)

    label_model = LabelModel()
    label_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    # Same non-uniform-marginal filter as in the first round.
    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    # The validation loader reuses the training candidates, but with hard
    # labels (argmax of the marginals) instead of probabilistic ones.
    valid_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            np.argmax(train_marginals, axis=1),
            train_idxs,
        ),
        split="valid",
        batch_size=100,
        shuffle=False,
    )

    # Reset Emmental's global state before training a fresh model.
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LogisticRegression")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader, valid_dataloader])

    # This round uses a stricter probability threshold (0.7 instead of 0.6).
    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    # The richer labeling functions must push F1 above 0.7.
    assert f1 > 0.7

    # Testing LSTM
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LSTM")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    # The LSTM is trained on the train loader only (no validation loader).
    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7
示例#26
0
def test_cand_gen_cascading_delete():
    """Check how deletions propagate between candidates and mentions.

    A single document is parsed, Part/Temp mentions and PartTemp candidates
    are extracted, and the test then verifies that:

    * deleting a row through the parent ``Candidate`` class also removes the
      matching ``PartTemp`` row,
    * deleting a ``PartTemp`` candidate leaves the mentions untouched,
    * clearing the mentions removes every remaining candidate.
    """
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    n_workers = 2
    doc_limit = 1

    db = Meta.init(CONN_STRING).Session()

    def row_count(model):
        """Return the number of rows of ``model`` visible to the session."""
        return db.query(model).count()

    # Parsing
    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor("tests/data/html/", max_docs=doc_limit)
    parser = Parser(
        db,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path="tests/data/pdf/",
    )
    parser.apply(preprocessor, parallelism=n_workers)
    assert row_count(Document) == doc_limit
    assert row_count(Sentence) == 799
    documents = db.query(Document).order_by(Document.name).all()

    # Mention extraction: one mention space, class and matcher per type.
    mention_spaces = [
        MentionNgramsPart(parts_by_doc=None, n_max=3),
        MentionNgramsTemp(n_max=2),
    ]
    mention_classes = [mention_subclass("Part"), mention_subclass("Temp")]
    Part, Temp = mention_classes

    mention_extractor = MentionExtractor(
        db, mention_classes, mention_spaces, [part_matcher, temp_matcher]
    )
    # Start from a clean slate before extracting.
    mention_extractor.clear_all()
    mention_extractor.apply(documents, parallelism=n_workers)

    for model, expected in ((Mention, 93), (Part, 70), (Temp, 23)):
        assert row_count(model) == expected
    part = db.query(Part).order_by(Part.id).all()[0]
    temp = db.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    candidate_extractor = CandidateExtractor(
        db, [PartTemp], throttlers=[temp_throttler]
    )
    candidate_extractor.apply(documents, split=0, parallelism=n_workers)

    for model in (PartTemp, Candidate):
        assert row_count(model) == 1432
    first_doc = documents[0]
    assert first_doc.name == "112823"
    assert len(first_doc.parts) == 70
    assert len(first_doc.temps) == 23

    # Deleting through the parent class must cascade to the child table.
    victim = db.query(Candidate).first()
    db.query(Candidate).filter_by(id=victim.id).delete(
        synchronize_session="fetch"
    )
    for model in (Candidate, PartTemp):
        assert row_count(model) == 1431

    # Deleting a candidate must not delete the mentions it refers to.
    victim = db.query(PartTemp).first()
    db.query(PartTemp).filter_by(id=victim.id).delete(
        synchronize_session="fetch"
    )
    assert row_count(PartTemp) == 1430
    assert row_count(Temp) == 23
    assert row_count(Part) == 70

    # Clearing mentions must also remove every candidate.
    mention_extractor.clear()
    for model in (Mention, Part, Temp, PartTemp, Candidate):
        assert row_count(model) == 0
示例#27
0
 def __init__(self, session):
     """Keep the session and create the Email mention/candidate classes.

     ``Email`` is the mention subclass named "Email"; ``Email_C`` is the
     candidate subclass named "Email_C" built over that single mention class.
     """
     self.session = session
     email_mention = mention_subclass("Email")
     self.Email = email_mention
     self.Email_C = candidate_subclass("Email_C", [email_mention])
示例#28
0
def test_feature_extraction():
    """Check feature-key counts for different feature extractor setups.

    After parsing one document and extracting PartTemp candidates, the
    candidates are featurized with the default feature library, with the
    default library plus one custom feature function, and with each of the
    textual/tabular/structural/visual libraries on its own. The test asserts
    that the default count equals the sum of the four single-library counts,
    and that the custom function adds exactly one key per candidate.
    """
    n_workers = 1
    doc_limit = 1

    db = Meta.init(CONN_STRING).Session()

    # Parsing
    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor("tests/data/html/", max_docs=doc_limit)
    parser = Parser(
        db,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path="tests/data/pdf/",
    )
    parser.apply(preprocessor, parallelism=n_workers)
    assert db.query(Document).count() == doc_limit
    assert db.query(Sentence).count() == 799
    documents = db.query(Document).order_by(Document.name).all()

    # Mention extraction: unigram mention spaces for both mention types.
    mention_spaces = [MentionNgrams(n_max=1), MentionNgrams(n_max=1)]
    mention_classes = [mention_subclass("Part"), mention_subclass("Temp")]
    Part, Temp = mention_classes

    mention_extractor = MentionExtractor(
        db, mention_classes, mention_spaces, [part_matcher, temp_matcher]
    )
    mention_extractor.apply(documents, parallelism=n_workers)

    assert documents[0].name == "112823"
    assert db.query(Part).count() == 58
    assert db.query(Temp).count() == 16
    part = db.query(Part).order_by(Part.id).all()[0]
    temp = db.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate extraction (no throttler in this test).
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    candidate_extractor = CandidateExtractor(db, [PartTemp])
    candidate_extractor.apply(documents, split=0, parallelism=n_workers)

    n_cands = db.query(PartTemp).count()

    def key_count_after_apply(featurizer):
        """Featurize split 0, count the FeatureKey rows, then clear them."""
        featurizer.apply(split=0, train=True, parallelism=n_workers)
        n_keys = db.query(FeatureKey).count()
        featurizer.clear(train=True)
        return n_keys

    # Featurization based on the default feature library.
    n_default_feats = key_count_after_apply(Featurizer(db, [PartTemp]))

    # Example feature extractor: emits one uniquely named feature per candidate.
    def feat_ext(candidates):
        if not isinstance(candidates, list):
            candidates = [candidates]
        for candidate in candidates:
            yield candidate.id, f"cand_id_{candidate.id}", 1

    # Default feature library plus the extra feature extractor above.
    custom_extractors = FeatureExtractor(customize_feature_funcs=[feat_ext])
    n_default_w_customized_features = key_count_after_apply(
        Featurizer(db, [PartTemp], feature_extractors=custom_extractors)
    )

    # Each feature library on its own, in the same order as before:
    # textual, tabular, structural, visual.
    n_per_library = []
    for library in ("textual", "tabular", "structural", "visual"):
        single_library = FeatureExtractor(features=[library])
        n_per_library.append(
            key_count_after_apply(
                Featurizer(db, [PartTemp], feature_extractors=single_library)
            )
        )

    # The default library is exactly the union of the four libraries.
    assert n_default_feats == sum(n_per_library)

    # The custom extractor contributes one extra key per candidate.
    assert n_default_w_customized_features == n_default_feats + n_cands
示例#29
0
 def __init__(self):
     """Create the Email mention class and the Email_C candidate class.

     ``Email_C`` is the candidate subclass named "Email_C" built over the
     single mention subclass named "Email".
     """
     email_mention = mention_subclass("Email")
     self.Email = email_mention
     self.Email_C = candidate_subclass("Email_C", [email_mention])
示例#30
0
def test_incremental(caplog):
    """Run an end-to-end test on incremental additions.

    The test parses a single document and runs mention extraction, candidate
    extraction, featurization and labeling on it. It then parses a second
    document with ``clear=False`` and checks that each stage can be extended
    to the new document (via ``apply(..., clear=False)`` or ``update``) with
    the expected cumulative row counts.

    Args:
        caplog: Log-capture fixture (pytest's ``caplog``); only used here to
            set the captured log level to INFO.
    """
    caplog.set_level(logging.INFO)

    PARALLEL = 1

    max_docs = 1

    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    # ------------------------------------------------------------------
    # Phase 1: process the first document from scratch.
    # ------------------------------------------------------------------
    docs_path = "tests/data/html/dtc114w.html"
    pdf_path = "tests/data/pdf/dtc114w.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    docs = corpus_parser.get_documents()
    # Previously this called get_documents() a second time, which made the
    # assertion below compare a result with itself. Only one parse has run
    # so far, so the most recently parsed documents are expected to match
    # the full document set.
    last_docs = corpus_parser.get_last_documents()

    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp],
        [part_ngrams, temp_ngrams],
        [part_matcher, temp_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 70
    assert session.query(Candidate).count() == 70

    # Grab candidate lists: one list per candidate class
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 70

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 70
    assert session.query(FeatureKey).count() == 512

    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (70, 512)
    assert len(featurizer.get_keys()) == 512

    # Test Dropping FeatureKey
    # NOTE(review): the key count is asserted to be unchanged after the
    # drop -- confirm this is the intended expectation for this key.
    featurizer.drop_keys(["CORE_e1_LENGTH_1"])
    assert session.query(FeatureKey).count() == 512

    # Labeling
    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    # lfs is a list of LF lists, one inner list per candidate class
    labeler.apply(
        split=0, lfs=[stg_temp_lfs], train=True, parallelism=PARALLEL
    )
    assert session.query(Label).count() == 70

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (70, 5)
    assert len(labeler.get_keys()) == 5

    # ------------------------------------------------------------------
    # Phase 2: add a second document without clearing existing data.
    # ------------------------------------------------------------------
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser.apply(doc_preprocessor, pdf_path=pdf_path, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    new_docs = corpus_parser.get_last_documents()

    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # Get mentions from just the new docs
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)

    # Counts are cumulative over both documents
    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Just run candidate extraction and assign to split 0
    candidate_extractor.apply(
        new_docs, split=0, parallelism=PARALLEL, clear=False
    )

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1502

    # Update features
    featurizer.update(new_docs, parallelism=PARALLEL)
    assert session.query(Feature).count() == 1502
    assert session.query(FeatureKey).count() == 2573
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (1502, 2573)
    assert len(featurizer.get_keys()) == 2573

    # Update Labels
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 1502
    assert session.query(LabelKey).count() == 6
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (1502, 6)

    # Test clear
    featurizer.clear(train=True)
    assert session.query(FeatureKey).count() == 0