Example No. 1
def test_pdf_word_list_is_sorted():
    """Test if pdf_word_list is sorted as expected.

    no_image_unsorted.html is originally created from pdf_simple/no_image.pdf,
    but the order of html elements like block and word has been changed to see if
    pdf_word_list is sorted as expected.
    """
    docs_path = "tests/data/html_simple/no_image_unsorted.html"
    pdf_path = "tests/data/pdf_simple"
    visual_parser = PdfVisualParser(pdf_path=pdf_path)
    with open(docs_path) as f:
        soup = BeautifulSoup(f, "html.parser")
    page = soup.find_all("page")[0]
    pdf_word_list, coordinate_map = visual_parser._coordinates_from_HTML(
        page, 1)

    # Check if words are sorted by block top
    assert set([content
                for (_, content) in pdf_word_list[:2]]) == {"Sample", "HTML"}
    # Check if words are sorted by top
    assert [content for (_, content) in pdf_word_list[2:7]] == [
        "This",
        "is",
        "an",
        "html",
        "that",
    ]
    # Check if words are sorted by left (#449)
    assert [content
            for (_, content) in pdf_word_list[:2]] == ["Sample", "HTML"]
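
A minimal, self-contained sketch of the sort order the test above checks; the names WordBox and sort_words are illustrative, not Fonduer's API. Words are ordered by block top first, then top, then left (#449).

from typing import List, NamedTuple


class WordBox(NamedTuple):
    block_top: int  # top coordinate of the enclosing block
    top: int        # top coordinate of the word itself
    left: int       # left coordinate of the word
    content: str


def sort_words(words: List[WordBox]) -> List[WordBox]:
    # Sort by block top, then top, then left, mirroring the expected order.
    return sorted(words, key=lambda w: (w.block_top, w.top, w.left))


# "Sample" precedes "HTML" (same line, smaller left); both precede "This",
# whose block starts lower on the page.
words = [
    WordBox(0, 35, 117, "HTML"),
    WordBox(10, 70, 35, "This"),
    WordBox(0, 35, 35, "Sample"),
]
assert [w.content for w in sort_words(words)] == ["Sample", "HTML", "This"]
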
Example No. 2
def test_parse_wo_tabular():
    """Test the parser without extracting tabular information."""
    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path)
    doc = next(preprocessor._parse_file(docs_path, "md"))

    # Create a Parser and parse the md document
    parser_udf = get_parser_udf(
        structural=True,
        tabular=False,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_path),
        language="en",
    )
    doc = parser_udf.apply(doc)

    # Check that doc has neither table nor cell
    assert len(doc.sections) == 1
    assert len(doc.paragraphs) == 44
    assert len(doc.figures) == 1
    assert len(doc.tables) == 0
    assert len(doc.cells) == 0
    assert len(doc.sentences) == 45

    # Check that sentences are associated with both section and paragraph.
    assert all([sent.section is not None for sent in doc.sentences])
    assert all([sent.paragraph is not None for sent in doc.sentences])

    # Check that sentences are NOT associated with cell
    assert all([sent.cell is None for sent in doc.sentences])
Example No. 3
def test_visual_parser_not_affected_by_order_of_sentences():
    """Test if visual_parser result is not affected by the order of sentences."""
    docs_path = "tests/data/html/2N6427.html"
    pdf_path = "tests/data/pdf/"

    # Initialize preprocessor, parser, visual_parser.
    # Note that parser is initialized with `visual=False` and that visual_parser
    # will be used to attach "visual" information to sentences after parsing.
    preprocessor = HTMLDocPreprocessor(docs_path)
    parser_udf = get_parser_udf(structural=True,
                                lingual=False,
                                tabular=True,
                                visual=False)
    visual_parser = PdfVisualParser(pdf_path=pdf_path)

    doc = parser_udf.apply(next(iter(preprocessor)))
    # Sort sentences by sentence.position
    doc.sentences = sorted(doc.sentences, key=attrgetter("position"))
    sentences0 = [
        sent for sent in visual_parser.parse(doc.name, doc.sentences)
    ]
    # Sort again in case visual_parser.link changes the order
    sentences0 = sorted(sentences0, key=attrgetter("position"))

    doc = parser_udf.apply(next(iter(preprocessor)))
    # Shuffle
    random.shuffle(doc.sentences)
    sentences1 = [
        sent for sent in visual_parser.parse(doc.name, doc.sentences)
    ]
    # Sort sentences by sentence.position
    sentences1 = sorted(sentences1, key=attrgetter("position"))

    # This should hold as both sentences are sorted by their position
    assert all([
        sent0.position == sent1.position
        for (sent0, sent1) in zip(sentences0, sentences1)
    ])

    # The following assertion should hold if the visual_parser result is not affected
    # by the order of sentences.
    assert all([
        sent0.left == sent1.left
        for (sent0, sent1) in zip(sentences0, sentences1)
    ])
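
The test above is an instance of a general order-invariance check that can be distilled into a few lines. This sketch is illustrative (check_order_invariance is not part of the test suite): run the same computation on a shuffled copy of the input, normalize both outputs by a stable key, and compare.

import random
from typing import Callable, Iterable, List, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def check_order_invariance(
    items: List[T],
    compute: Callable[[List[T]], Iterable[U]],
    key: Callable[[U], object],
) -> bool:
    # Run once on the original order and once on a shuffled copy, then
    # normalize both outputs with a stable sort key before comparing.
    ordered = sorted(compute(list(items)), key=key)
    shuffled = list(items)
    random.shuffle(shuffled)
    return ordered == sorted(compute(shuffled), key=key)


assert check_order_invariance([3, 1, 2], lambda xs: [x * 2 for x in xs], key=lambda x: x)
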
Example No. 4
def test_parse_style():
    """Test style tag parsing."""
    logger = logging.getLogger(__name__)

    docs_path = "tests/data/html_extended/ext_diseases.html"
    pdf_path = "tests/data/pdf_extended/"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path)
    doc = next(preprocessor._parse_file(docs_path, "ext_diseases"))

    # Create a Parser and parse the diseases document
    parser_udf = get_parser_udf(
        structural=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    doc = parser_udf.apply(doc)

    # Grab the sentences parsed by the Parser
    sentences = doc.sentences

    logger.warning(f"Doc: {doc}")
    for i, sentence in enumerate(sentences):
        logger.warning(f"    Sentence[{i}]: {sentence.html_attrs}")

    # sentences for testing
    sub_sentences = [
        {
            "index":
            6,
            "attr": [
                "class=col-header",
                "hobbies=work:hard;play:harder",
                "type=phenotype",
                "style=background: #f1f1f1; color: aquamarine; font-size: 18px;",
            ],
        },
        {
            "index": 9,
            "attr": ["class=row-header", "style=background: #f1f1f1;"]
        },
        {
            "index": 11,
            "attr": ["class=cell", "style=text-align: center;"]
        },
    ]

    # Assertions
    assert all(sentences[p["index"]].html_attrs == p["attr"]
               for p in sub_sentences)
Example No. 5
def parse(session: Session, docs_path: str, pdf_path: str) -> List[Document]:
    """Parse documents using Parser UDF Runner."""
    # Preprocessor for the Docs
    doc_preprocessor = HTMLDocPreprocessor(docs_path)

    # Create a Parser and parse the documents
    corpus_parser = Parser(
        session,
        parallelism=1,
        structural=True,
        lingual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )

    corpus_parser.clear()
    corpus_parser.apply(doc_preprocessor)
    return corpus_parser.get_documents()
Example No. 6
def parse_doc(docs_path: str, file_name: str, pdf_path: Optional[str] = None):
    """Parse documents from given path."""
    max_docs = 1

    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    doc = next(doc_preprocessor._parse_file(docs_path, file_name))

    # Create a Parser and parse the document
    parser_udf = get_parser_udf(
        structural=True,
        tabular=True,
        lingual=True,
        visual=bool(pdf_path),
        visual_parser=PdfVisualParser(pdf_path) if pdf_path else None,
        language="en",
    )
    doc = parser_udf.apply(doc)
    return doc
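
A hypothetical call, assuming the test-data layout used elsewhere in this listing; when pdf_path is omitted, parse_doc skips visual parsing entirely.

# With visual parsing (a matching PDF exists under pdf_path):
doc = parse_doc("tests/data/html_simple/md.html", "md", pdf_path="tests/data/pdf_simple/")

# Without visual parsing:
doc = parse_doc("tests/data/html_simple/md.html", "md")
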
Example No. 7
def test_simple_parser():
    """Unit test of Parser on a single document with lingual features off."""
    logger = logging.getLogger(__name__)

    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path)
    doc = next(preprocessor._parse_file(docs_path, "md"))

    # Check that doc has a name
    assert doc.name == "md"

    # Create a Parser and parse the md document
    parser_udf = get_parser_udf(
        structural=True,
        lingual=False,
        visual=True,
        visual_parser=PdfVisualParser(pdf_path),
        lingual_parser=SimpleParser(delim="NoDelim"),
    )
    doc = parser_udf.apply(doc)

    logger.info(f"Doc: {doc}")
    for i, sentence in enumerate(doc.sentences):
        logger.info(f"    Sentence[{i}]: {sentence.text}")

    header = sorted(doc.sentences, key=lambda x: x.position)[0]
    # Test structural attributes
    assert header.xpath == "/html/body/h1"
    assert header.html_tag == "h1"
    assert header.html_attrs == ["id=sample-markdown"]

    # Test lingual attributes
    assert header.ner_tags == ["", ""]
    assert header.dep_labels == ["", ""]
    assert header.dep_parents == [0, 0]
    assert header.lemmas == ["", ""]
    assert header.pos_tags == ["", ""]

    assert len(doc.sentences) == 44
Example No. 8
def test_warning_on_incorrect_filename():
    """Test that a warning is issued on invalid pdf."""
    docs_path = "tests/data/html_simple/md_para.html"
    pdf_path = "tests/data/html_simple/"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path)
    doc = next(preprocessor._parse_file(docs_path, "md_para"))

    # Create a Parser and parse the md_para document
    parser_udf = get_parser_udf(
        structural=True,
        tabular=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    with pytest.warns(RuntimeWarning) as record:
        doc = parser_udf.apply(doc)
    # The `with pytest.warns(...)` block itself fails if no RuntimeWarning is
    # raised; additionally check that at least one warning was recorded.
    assert len(record) >= 1
Example No. 9
def test_warning_on_missing_pdf():
    """Test that a warning is issued on invalid pdf."""
    docs_path = "tests/data/html_simple/table_span.html"
    pdf_path = "tests/data/pdf_simple/"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path)
    doc = next(iter(preprocessor))

    # Create a Parser and parse the table_span document
    parser_udf = get_parser_udf(
        structural=True,
        tabular=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    with pytest.warns(RuntimeWarning) as record:
        doc = parser_udf.apply(doc)
    assert len(record) == 1
    assert "Visual parse failed" in record[0].message.args[0]
Example No. 10
def test_parser_no_image():
    """Unit test of Parser on a single document that has a figure without image."""
    docs_path = "tests/data/html_simple/no_image.html"
    pdf_path = "tests/data/pdf_simple/"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path)
    doc = next(preprocessor._parse_file(docs_path, "no_image"))

    # Check that doc has a name
    assert doc.name == "no_image"

    # Create a Parser and parse the no_image document
    parser_udf = get_parser_udf(
        structural=True,
        lingual=False,
        visual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    doc = parser_udf.apply(doc)

    # Check that doc has no figures
    assert len(doc.figures) == 0
Example No. 11
def mention_setup():
    """Set up mentions."""
    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path)
    doc = next(iter(preprocessor))

    # Create a Parser and parse the md document
    parser_udf = get_parser_udf(
        structural=True,
        tabular=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_path),
        language="en",
    )
    doc = parser_udf.apply(doc)

    # Create 1-gram span mentions
    space = MentionNgrams(n_min=1, n_max=1)
    mentions = [tc for tc in space.apply(doc)]
    return mentions
Example No. 12
def test_parse_md_details():
    """Test the parser with the md document."""
    logger = logging.getLogger(__name__)

    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path)
    doc = next(preprocessor._parse_file(docs_path, "md"))

    # Check that doc has a name
    assert doc.name == "md"

    # Check that doc does not have any of these
    assert len(doc.figures) == 0
    assert len(doc.tables) == 0
    assert len(doc.cells) == 0
    assert len(doc.sentences) == 0

    # Create a Parser and parse the md document
    parser_udf = get_parser_udf(
        structural=True,
        tabular=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_path),
        language="en",
    )
    doc = parser_udf.apply(doc)

    # Check that doc has a figure
    assert len(doc.figures) == 1
    assert doc.figures[0].url == "http://placebear.com/200/200"
    assert doc.figures[0].position == 0
    assert doc.figures[0].section.position == 0
    assert doc.figures[0].stable_id == "md::figure:0"

    # Check that doc has a table
    assert len(doc.tables) == 1
    assert doc.tables[0].position == 0
    assert doc.tables[0].section.position == 0
    assert doc.tables[0].document.name == "md"

    # Check that doc has cells
    assert len(doc.cells) == 16
    cells = list(doc.cells)
    assert cells[0].row_start == 0
    assert cells[0].col_start == 0
    assert cells[0].position == 0
    assert cells[0].document.name == "md"
    assert cells[0].table.position == 0

    assert cells[10].row_start == 2
    assert cells[10].col_start == 2
    assert cells[10].position == 10
    assert cells[10].document.name == "md"
    assert cells[10].table.position == 0

    # Check that doc has sentences
    assert len(doc.sentences) == 45
    sent = sorted(doc.sentences, key=lambda x: x.position)[25]
    assert sent.text == "Spicy"
    assert sent.table.position == 0
    assert sent.table.section.position == 0
    assert sent.cell.row_start == 0
    assert sent.cell.col_start == 2

    # Check if the tail text is processed after inner elements (#333)
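    # ("Tail text" is text that follows a child element inside its parent; e.g.
    # in <p><b>bold</b>. Even</p>, ". Even" is the tail of <b> and must be
    # emitted after the inner element's text, not before it.)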
    assert [sent.text for sent in doc.sentences[14:18]] == [
        "italics and later",
        "bold",
        ".",
        "Even",
    ]

    # Check abs_char_offsets (#332)
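    # Each abs_char_offset indexes into the document-level concatenation of all
    # sentence texts, so a word's first character can be looked up directly.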
    text = "".join([sent.text for sent in doc.sentences])
    for sent in doc.sentences:
        for abs_char_offset, word in zip(sent.abs_char_offsets, sent.words):
            assert text[abs_char_offset] == word[0]

    logger.info(f"Doc: {doc}")
    for i, sentence in enumerate(doc.sentences):
        logger.info(f"    Sentence[{i}]: {sentence.text}")

    header = sorted(doc.sentences, key=lambda x: x.position)[0]
    # Test structural attributes
    assert header.xpath == "/html/body/h1"
    assert header.html_tag == "h1"
    assert header.html_attrs == ["id=sample-markdown"]

    # Test visual attributes
    assert header.page == [1, 1]
    assert header.top == [35, 35]
    assert header.bottom == [61, 61]
    assert header.right == [111, 231]
    assert header.left == [35, 117]
    # Test lingual attributes
    # when lingual=True, some value other than "" should be filled-in.
    assert all(sent.ner_tags)
    assert all(sent.dep_labels)

    # Test whether nlp information corresponds to sentence words
    for sent in doc.sentences:
        assert len(sent.words) == len(sent.lemmas)
        assert len(sent.words) == len(sent.pos_tags)
        assert len(sent.words) == len(sent.ner_tags)
        assert len(sent.words) == len(sent.dep_parents)
        assert len(sent.words) == len(sent.dep_labels)
Example No. 13
def test_parse_document_diseases():
    """Unit test of Parser on a single document.

    This tests both the structural and visual parse of the document.
    """
    logger = logging.getLogger(__name__)

    docs_path = "tests/data/html_simple/diseases.html"
    pdf_path = "tests/data/pdf_simple/"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path)
    doc = next(preprocessor._parse_file(docs_path, "diseases"))

    # Check that doc has a name
    assert doc.name == "diseases"

    # Create a Parser and parse the diseases document
    parser_udf = get_parser_udf(
        structural=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    doc = parser_udf.apply(doc)

    logger.info(f"Doc: {doc}")
    for sentence in doc.sentences:
        logger.info(f"    Sentence: {sentence.text}")

    # Check captions
    assert len(doc.captions) == 2
    caption = sorted(doc.sentences, key=lambda x: x.position)[20]
    assert caption.paragraph.caption.position == 0
    assert caption.paragraph.caption.table.position == 0
    assert caption.text == "Table 1: Infectious diseases and where to find them."
    assert caption.paragraph.position == 18

    # Check figures
    assert len(doc.figures) == 0

    # Check that doc has a table
    assert len(doc.tables) == 3
    assert doc.tables[0].position == 0
    assert doc.tables[0].document.name == "diseases"

    # Check that doc has cells
    assert len(doc.cells) == 25

    sentence = sorted(doc.sentences, key=lambda x: x.position)[10]
    logger.info(f"  {sentence}")

    # Check the sentence's cell
    assert sentence.table.position == 0
    assert sentence.cell.row_start == 2
    assert sentence.cell.col_start == 1
    assert sentence.cell.position == 4

    # Test structural attributes
    assert sentence.xpath == "/html/body/table[1]/tbody/tr[3]/td[1]/p"
    assert sentence.html_tag == "p"
    assert sentence.html_attrs == ["class=s6", "style=padding-top: 1pt"]

    # Test visual attributes
    assert sentence.page == [1, 1, 1]
    assert sentence.top == [342, 296, 356]
    assert sentence.left == [318, 369, 318]

    # Test lingual attributes
    # when lingual=True, some value other than "" should be filled-in.
    assert all(sentence.ner_tags)
    assert all(sentence.dep_labels)

    assert len(doc.sentences) == 37
Example No. 14
def test_incremental(database_session):
    """Run an end-to-end test on incremental additions."""
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 1

    session = database_session

    docs_path = "tests/data/html/dtc114w.html"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    docs = corpus_parser.get_documents()
    last_docs = corpus_parser.get_documents()

    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    mention_extractor = MentionExtractor(session, [Part, Temp],
                                         [part_ngrams, temp_ngrams],
                                         [part_matcher, temp_matcher])

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Test if clear=True works
    mention_extractor.apply(docs, parallelism=PARALLEL, clear=True)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Candidate Extraction
    candidate_extractor = CandidateExtractor(session, [PartTemp],
                                             throttlers=[temp_throttler])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 70
    assert session.query(Candidate).count() == 70

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 70

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == len(train_cands[0])
    num_feature_keys = session.query(FeatureKey).count()
    assert num_feature_keys == 514

    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (len(train_cands[0]), num_feature_keys)
    assert len(featurizer.get_keys()) == num_feature_keys

    # Test Dropping FeatureKey
    featurizer.drop_keys(["BASIC_e1_LENGTH_1"])
    assert session.query(FeatureKey).count() == num_feature_keys - 1

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    labeler.apply(split=0,
                  lfs=[stg_temp_lfs],
                  train=True,
                  parallelism=PARALLEL)
    assert session.query(Label).count() == len(train_cands[0])

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (len(train_cands[0]), 5)
    assert len(labeler.get_keys()) == 5

    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser.apply(doc_preprocessor, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    new_docs = corpus_parser.get_last_documents()

    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # Get mentions from just the new docs
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)
    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Test if existing mentions are skipped.
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)
    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Just run candidate extraction and assign to split 0
    candidate_extractor.apply(new_docs,
                              split=0,
                              parallelism=PARALLEL,
                              clear=False)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1501

    # Test if existing candidates are skipped.
    candidate_extractor.apply(new_docs,
                              split=0,
                              parallelism=PARALLEL,
                              clear=False)
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1501

    # Update features
    featurizer.update(new_docs, parallelism=PARALLEL)
    assert session.query(Feature).count() == len(train_cands[0])
    num_feature_keys = session.query(FeatureKey).count()
    assert num_feature_keys == 2608
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (len(train_cands[0]), num_feature_keys)
    assert len(featurizer.get_keys()) == num_feature_keys

    # Update LF_storage_row. Now it always returns ABSTAIN.
    @labeling_function(name="LF_storage_row")
    def LF_storage_row_updated(c):
        return ABSTAIN

    stg_temp_lfs = [
        LF_storage_row_updated,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    # Update Labels
    labeler.update(docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == len(train_cands[0])
    # Only 5 because LF_storage_row doesn't apply to any doc (always ABSTAIN)
    num_label_keys = session.query(LabelKey).count()
    assert num_label_keys == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (len(train_cands[0]), num_label_keys)

    # Test clear
    featurizer.clear(train=True)
    assert session.query(FeatureKey).count() == 0
Example No. 15
def test_too_many_clients_error_should_not_happen(database_session):
    """Too many clients error should not happens."""
    PARALLEL = 32
    logger.info("Parallel: {PARALLEL}")

    def do_nothing_matcher(fig):
        return True

    max_docs = 1
    session = database_session

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session,
        structural=True,
        lingual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt]
    )  # Pass, no throttler

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    candidate_extractor.clear_all(split=0)

    # Test with None in throttlers in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, None]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
Example No. 16
def test_parse_md_paragraphs():
    """Unit test of Paragraph parsing."""
    docs_path = "tests/data/html_simple/md_para.html"
    pdf_path = "tests/data/pdf_simple/"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path)
    doc = next(preprocessor._parse_file(docs_path, "md_para"))

    # Check that doc has a name
    assert doc.name == "md_para"

    # Create a Parser and parse the md_para document
    parser_udf = get_parser_udf(
        structural=True,
        tabular=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    doc = parser_udf.apply(doc)

    # Check that doc has a figure
    assert len(doc.figures) == 6
    assert doc.figures[0].url == "http://placebear.com/200/200"
    assert doc.figures[0].position == 0
    assert doc.figures[0].section.position == 0
    assert len(doc.figures[0].captions) == 0
    assert doc.figures[0].stable_id == "md_para::figure:0"
    assert doc.figures[0].cell.position == 13
    assert (doc.figures[2].url ==
            "http://html5doctor.com/wp-content/uploads/2010/03/kookaburra.jpg")
    assert doc.figures[2].position == 2
    assert len(doc.figures[2].captions) == 1
    assert len(doc.figures[2].captions[0].paragraphs[0].sentences) == 3
    assert (doc.figures[2].captions[0].paragraphs[0].sentences[0].text ==
            "Australian Birds.")
    assert len(doc.figures[4].captions) == 0
    assert (doc.figures[4].url ==
            "http://html5doctor.com/wp-content/uploads/2010/03/pelican.jpg")

    # Check that doc has a table
    assert len(doc.tables) == 1
    assert doc.tables[0].position == 0
    assert doc.tables[0].section.position == 0

    # Check that doc has cells
    assert len(doc.cells) == 16
    cells = list(doc.cells)
    assert cells[0].row_start == 0
    assert cells[0].col_start == 0
    assert cells[0].position == 0
    assert cells[0].table.position == 0

    assert cells[10].row_start == 2
    assert cells[10].col_start == 2
    assert cells[10].position == 10
    assert cells[10].table.position == 0

    # Check that doc has sentences
    assert len(doc.sentences) == 51
    sentences = sorted(doc.sentences, key=lambda x: x.position)
    sent1 = sentences[1]
    sent2 = sentences[2]
    sent3 = sentences[3]
    assert sent1.text == "This is some basic, sample markdown."
    assert sent2.text == (
        "Unlike the other markdown document, however, "
        "this document actually contains paragraphs of text.")
    assert sent1.paragraph.position == 1
    assert sent1.section.position == 0
    assert sent2.paragraph.position == 1
    assert sent2.section.position == 0
    assert sent3.paragraph.position == 1
    assert sent3.section.position == 0

    assert len(doc.paragraphs) == 46
    assert len(doc.paragraphs[1].sentences) == 3
    assert len(doc.paragraphs[2].sentences) == 1
Example No. 17
def test_cand_gen_cascading_delete(database_session):
    """Test cascading the deletion of candidates."""
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 1
    session = database_session

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session,
        structural=True,
        lingual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )
    mention_extractor.clear_all()
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Mention).count() == 93
    assert session.query(Part).count() == 70
    assert session.query(Temp).count() == 23
    part = session.query(Part).order_by(Part.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1431
    assert session.query(Candidate).count() == 1431
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].temps) == 23

    # Delete from parent class should cascade to child
    x = session.query(Candidate).first()
    session.query(Candidate).filter_by(id=x.id).delete(synchronize_session="fetch")
    assert session.query(Candidate).count() == 1430
    assert session.query(PartTemp).count() == 1430

    # Test that deletion of a Candidate does not delete the Mention
    x = session.query(PartTemp).first()
    candidate = session.query(PartTemp).filter_by(id=x.id).first()
    session.delete(candidate)
    assert session.query(PartTemp).count() == 1429
    assert session.query(Temp).count() == 23
    assert session.query(Part).count() == 70

    # Clearing Mentions should also delete Candidates
    mention_extractor.clear()
    assert session.query(Mention).count() == 0
    assert session.query(Part).count() == 0
    assert session.query(Temp).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(Candidate).count() == 0
Example No. 18
def test_non_existent_pdf_path_should_fail():
    """Test if a non-existent raises an error."""
    pdf_path = "dummy_path"
    with pytest.raises(ValueError):
        PdfVisualParser(pdf_path=pdf_path)
Example No. 19
def test_e2e(database_session):
    """Run an end-to-end test on documents of the hardware domain."""
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 12

    fonduer.init_logging(
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = database_session

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    corpus_parser.apply(doc_preprocessor)
    assert session.query(Document).count() == max_docs

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (doc_name, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.
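    # For example (illustrative, not executed here):
    #   cands = candidate_extractor.get_candidates(split=0)
    #   cands[0]  # candidates for the first relation (PartTemp)
    #   cands[1]  # candidates for the second relation (PartVolt)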

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions(docs)) == 3

    # Candidate Extraction
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt],
        throttlers=[temp_throttler, volt_throttler])

    for i, docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(docs, split=i, parallelism=PARALLEL)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)

    # Candidate lists should be deterministically sorted.
    assert ("112823::implicit_span_mention:11059:11065:part_expander:0" ==
            train_cands[0][0][0].context.get_stable_id())
    assert ("112823::implicit_span_mention:2752:2754:temp_expander:0" ==
            train_cands[0][0][1].context.get_stable_id())

    assert len(train_cands) == 2
    assert len(candidate_extractor.get_candidates(docs)) == 2

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    num_feature_keys = session.query(FeatureKey).count()
    assert num_feature_keys == 1278

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["BASIC_e0_CONTAINS_WORDS_[BC182]"])
    assert session.query(FeatureKey).count() == num_feature_keys - 1

    # Should only remove part_volt as a relation and leave part_temp
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e0_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    featurizer.drop_keys(["DDL_e0_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name ==
        "DDL_e0_LEMMA_SEQ_[bc182]").one().candidate_classes == ["part_temp"]
    assert session.query(FeatureKey).count() == num_feature_keys - 1

    # Inserting the removed key
    featurizer.upsert_keys(["DDL_e0_LEMMA_SEQ_[bc182]"],
                           candidate_classes=[PartTemp, PartVolt])
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e0_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    assert session.query(FeatureKey).count() == num_feature_keys - 1
    # Removing the key again
    featurizer.drop_keys(["DDL_e0_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e0_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == num_feature_keys - 2
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    # the number of Features should equal the total number of train candidates
    num_features = session.query(Feature).count()
    assert num_features == len(train_cands[0]) + len(train_cands[1])
    num_feature_keys = session.query(FeatureKey).count()
    assert num_feature_keys == 4629
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (len(train_cands[0]), num_feature_keys)
    assert F_train[1].shape == (len(train_cands[1]), num_feature_keys)
    assert len(featurizer.get_keys()) == num_feature_keys

    featurizer.apply(split=1, parallelism=PARALLEL)
    # the number of Features should increase by the total number of dev candidates
    num_features += len(dev_cands[0]) + len(dev_cands[1])
    assert session.query(Feature).count() == num_features

    assert session.query(FeatureKey).count() == num_feature_keys
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (len(dev_cands[0]), num_feature_keys)
    assert F_dev[1].shape == (len(dev_cands[1]), num_feature_keys)

    featurizer.apply(split=2, parallelism=PARALLEL)
    # the number of Features should increase by the total number of test candidates
    num_features += len(test_cands[0]) + len(test_cands[1])
    assert session.query(Feature).count() == num_features
    assert session.query(FeatureKey).count() == num_feature_keys
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (len(test_cands[0]), num_feature_keys)
    assert F_test[1].shape == (len(test_cands[1]), num_feature_keys)

    gold_file = "tests/data/hardware_tutorial_gold.csv"

    labeler = Labeler(session, [PartTemp, PartVolt])

    # This should raise an error, since gold labels are not yet loaded.
    with pytest.raises(ValueError):
        _ = labeler.get_gold_labels(train_cands, annotator="gold")

    labeler.apply(
        docs=last_docs,
        lfs=[[gold], [gold]],
        table=GoldLabel,
        train=True,
        parallelism=PARALLEL,
    )
    # All candidates should now be gold-labeled.
    assert session.query(GoldLabel).count() == session.query(Candidate).count()

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]

    with pytest.raises(ValueError):
        labeler.apply(split=0,
                      lfs=stg_temp_lfs,
                      train=True,
                      parallelism=PARALLEL)

    labeler.apply(
        docs=train_docs,
        lfs=[stg_temp_lfs, ce_v_max_lfs],
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(Label).count() == len(train_cands[0]) + len(
        train_cands[1])
    num_label_keys = session.query(LabelKey).count()
    assert num_label_keys == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (len(train_cands[0]), num_label_keys)
    assert L_train[1].shape == (len(train_cands[1]), num_label_keys)
    assert len(labeler.get_keys()) == num_label_keys

    # Test Dropping LabelerKey
    labeler.drop_keys(["LF_storage_row"])
    assert len(labeler.get_keys()) == num_label_keys - 1

    # Test Upserting LabelerKey
    labeler.upsert_keys(["LF_storage_row"])
    assert "LF_storage_row" in [label.name for label in labeler.get_keys()]

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (len(train_cands[0]), 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (len(train_cands[0]), 1)

    label_model = LabelModel(cardinality=2)
    label_model.fit(L_train=L_train[0], n_epochs=500, seed=1234, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    # Collect word counter
    word_counter = collect_word_counter(train_cands)

    emmental.init(fonduer.Meta.log_path)

    # Training config
    config = {
        "meta_config": {
            "verbose": False
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.0
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    f"{ATTRIBUTE}/{ATTRIBUTE}/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }
    emmental.Meta.update_config(config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]
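    # For arity=2 this yields ["~~[[0", "0]]~~", "~~[[1", "1]]~~"]: a matched
    # pair of markers wrapping each of the two candidate spans.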

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)

    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LogisticRegression")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(ATTRIBUTE, test_cands[0], F_test[0],
                               emb_layer.word2id, 2),
        split="test",
        batch_size=100,
        shuffle=False,
    )

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.6)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")
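    # Standard definitions: prec = TP / (TP + FP), rec = TP / (TP + FN), and
    # f1 is their harmonic mean; each is guarded against division by zero.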

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert 0.3 < f1 < 0.7

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0,
                   lfs=[stg_temp_lfs_2, ce_v_max_lfs],
                   parallelism=PARALLEL)
    assert session.query(Label).count() == len(train_cands[0]) + len(
        train_cands[1])
    num_label_keys = session.query(LabelKey).count()
    assert num_label_keys == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (len(train_cands[0]), num_label_keys)

    label_model = LabelModel(cardinality=2)
    label_model.fit(L_train=L_train[0], n_epochs=500, seed=1234, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    valid_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            np.argmax(train_marginals, axis=1),
            train_idxs,
        ),
        split="valid",
        batch_size=100,
        shuffle=False,
    )

    # Testing STL LogisticRegression
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(
        ATTRIBUTE,
        2,
        F_train[0].shape[1],
        2,
        emb_layer,
        model="LogisticRegression",
        mode="STL",
    )

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader, valid_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7

    # Testing STL LSTM
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LSTM",
                        mode="STL")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7

    # Testing MTL LogisticRegression
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(
        ATTRIBUTE,
        2,
        F_train[0].shape[1],
        2,
        emb_layer,
        model="LogisticRegression",
        mode="MTL",
    )

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader, valid_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7

    # Testing LSTM
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LSTM",
                        mode="MTL")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7