示例#1
0
def test_subclass_before_meta_init():
    """Check that mention subclasses can be created before and after Meta.init.

    One subclass is created first, then Meta is initialized against the
    database named by ``DB`` on localhost, then a second subclass is created.
    """
    early_cls = mention_subclass("Part")
    logger.info(f"Create a mention subclass '{early_cls.__tablename__}'")

    # Initialize Meta and open a session between the two subclass creations.
    conn_string = "postgresql://localhost:5432/" + DB
    meta = Meta.init(conn_string)
    meta.Session()

    late_cls = mention_subclass("Temp")
    logger.info(f"Create a mention subclass '{late_cls.__tablename__}'")
示例#2
0
def test_subclass_before_meta_init():
    """Check that mention subclasses can be created around Meta initialization.

    A subclass is created first, Meta is then initialized with
    ``CONN_STRING``, and a second subclass is created afterwards.
    """
    first_cls = mention_subclass("Part")
    logger.info(f"Create a mention subclass '{first_cls.__tablename__}'")

    # Initialize Meta and open a session before creating the second subclass.
    meta = Meta.init(CONN_STRING)
    meta.Session()

    second_cls = mention_subclass("Temp")
    logger.info(f"Create a mention subclass '{second_cls.__tablename__}'")
示例#3
0
def test_subclass_before_meta_init(caplog):
    """Check that mention subclasses can be created around Meta initialization.

    A subclass is created first, Meta is then initialized against the
    database named by ``DB`` on localhost, and a second subclass is created.
    """
    caplog.set_level(logging.INFO)

    target_db = "postgresql://localhost:5432/" + DB

    pre_init_cls = mention_subclass("Part")
    logger.info(f"Create a mention subclass '{pre_init_cls.__tablename__}'")

    # Initialize Meta and open a session between the two subclass creations.
    meta = Meta.init(target_db)
    meta.Session()

    post_init_cls = mention_subclass("Temp")
    logger.info(f"Create a mention subclass '{post_init_cls.__tablename__}'")
示例#4
0
def test_subclass_before_meta_init():
    """Test if mention (candidate) subclass can be created before Meta init.

    The test creates a database named ``DB`` on the PostgreSQL server
    described by the ``POSTGRES_HOST``, ``POSTGRES_PORT``, ``PGUSER`` and
    ``PGPASSWORD`` environment variables, and drops it again afterwards.

    Every teardown step runs in a ``finally`` block so that a failure in the
    middle of the test does not leave the database, the engine, the cursor or
    the connection behind for subsequent tests.
    """
    # Test if mention (candidate) subclass can be created
    Part = mention_subclass("Part")
    logger.info(f"Create a mention subclass '{Part.__tablename__}'")

    # Setup a database
    con = psycopg2.connect(
        host=os.environ["POSTGRES_HOST"],
        port=os.environ["POSTGRES_PORT"],
        user=os.environ["PGUSER"],
        password=os.environ["PGPASSWORD"],
    )
    try:
        # Autocommit: CREATE/DROP DATABASE cannot run inside a transaction.
        con.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        cursor = con.cursor()
        try:
            cursor.execute(f'create database "{DB}";')
            try:
                session = Meta.init(CONN_STRING).Session()
                try:
                    # Test if another mention subclass can be created
                    Temp = mention_subclass("Temp")
                    logger.info(
                        f"Create a mention subclass '{Temp.__tablename__}'"
                    )
                finally:
                    # Release every connection to DB before it is dropped.
                    session.close()
                    Meta.engine.dispose()
                    Meta.engine = None
            finally:
                # Teardown the database even if the test body failed.
                cursor.execute(f'drop database "{DB}";')
        finally:
            cursor.close()
    finally:
        con.close()
示例#5
0
def test_multimodal_cand(caplog):
    """Test multimodal candidate generation.

    Parses a single HTML document, extracts mentions for eight context types
    with a matcher that accepts everything, and checks the number of mentions
    stored for each type.
    """
    caplog.set_level(logging.INFO)

    PARALLEL = 4
    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()
    docs_path = "tests/data/pure_html/radiology.html"

    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 35
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction

    # One mention subclass per context type, created in this order.
    subclasses = {
        name: mention_subclass(name)
        for name in (
            "m_doc",
            "m_sec",
            "m_tab",
            "m_fig",
            "m_cell",
            "m_para",
            "m_cap",
            "m_sent",
        )
    }

    # The mention space paired with each subclass.
    spaces = {
        "m_doc": MentionDocuments(),
        "m_sec": MentionSections(),
        "m_tab": MentionTables(),
        "m_fig": MentionFigures(),
        "m_cell": MentionCells(),
        "m_para": MentionParagraphs(),
        "m_cap": MentionCaptions(),
        "m_sent": MentionSentences(),
    }

    # (subclass name, expected mention count), listed in extraction order.
    expected_counts = [
        ("m_doc", 1),
        ("m_cap", 2),
        ("m_sec", 5),
        ("m_tab", 2),
        ("m_fig", 2),
        ("m_para", 30),
        ("m_sent", 35),
        ("m_cell", 21),
    ]

    mention_classes = [subclasses[name] for name, _ in expected_counts]
    mention_spaces = [spaces[name] for name, _ in expected_counts]
    matchers = [DoNothingMatcher()] * 8

    mention_extractor = MentionExtractor(
        session, mention_classes, mention_spaces, matchers, parallelism=PARALLEL
    )
    mention_extractor.apply(docs)

    for name, expected in expected_counts:
        assert session.query(subclasses[name]).count() == expected
示例#6
0
def test_meta_connection_strings():
    """Sanity-check how Meta.init handles postgres connection strings.

    Three malformed strings must raise ``ValueError``; two well-formed ones
    must set ``Meta.DBNAME`` to the database name at the end of the string.
    """
    rejected = (
        "postgresql" + DB,
        "sqlite://somethingsilly" + DB,
        "postgresql://somethingsilly:5432/",
    )
    for conn_string in rejected:
        with pytest.raises(ValueError):
            Meta.init(conn_string).Session()

    for db_name in (DB, "cand_test"):
        Meta.init("postgresql://localhost:5432/" + db_name).Session()
        assert Meta.DBNAME == db_name
示例#7
0
def test_spacy_integration(caplog):
    """Run a simple end-to-end parse using spaCy as the parser.

    The database is used the same way a user would use it from a notebook:
    documents are parsed, read back through the session, and counted.
    """
    # INFO-level capture through caplog is intentionally not enabled here.
    logger = logging.getLogger(__name__)

    PARALLEL = 2  # Travis only gives 2 cores
    max_docs = 2

    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    html_dir = "tests/data/html_simple/"
    pdf_dir = "tests/data/pdf_simple/"

    html_docs = HTMLPreprocessor(html_dir, max_docs=max_docs)
    corpus_parser = OmniParser(
        structural=True,
        lingual=True,
        visual=False,
        pdf_path=pdf_dir,
    )
    corpus_parser.apply(html_docs, parallelism=PARALLEL)

    # Log every parsed phrase, grouped by document name.
    for parsed_doc in session.query(Document).order_by(Document.name).all():
        logger.info("Doc: {}".format(parsed_doc.name))
        for parsed_phrase in parsed_doc.phrases:
            logger.info("  Phrase: {}".format(parsed_phrase.text))

    assert session.query(Document).count() == 2
    assert session.query(Phrase).count() == 81
示例#8
0
def parse(html_location, database):
    """
    Run the Fonduer parser over a set of HTML files.
    :param html_location: HTML files generated by ``parse_preprocess.py``.
    :param database: db connection string.
    """
    meta = Meta.init(database)
    db_session = meta.Session()
    html_docs = HTMLDocPreprocessor(html_location)
    # Structural and lingual parsing are both enabled.
    Parser(db_session, structural=True, lingual=True).apply(html_docs)
    def load_context(self, context: PythonModelContext) -> None:
        """Prepare the Fonduer pipeline and the trained model(s) for inference.

        Reads the pyfunc flavor configuration found under ``self.model_path``,
        opens a database session with the configured connection string, and
        stores on ``self`` the parser, the mention and candidate extractors,
        and, depending on the configured model type, either a featurizer with
        a discriminative model or a labeler with one label model per
        candidate class.

        :param context: Model context passed in by the caller; it is not read
            by this method.
        :raises RuntimeError: If the flavor configuration has no connection
            string.
        """
        # Configure logging for Fonduer
        init_logging(log_dir="logs")
        logger.info("loading context")

        pyfunc_conf = _get_flavor_configuration(model_path=self.model_path,
                                                flavor_name=pyfunc.FLAVOR_NAME)
        conn_string = pyfunc_conf.get(CONN_STRING, None)
        if conn_string is None:
            raise RuntimeError("conn_string is missing from MLmodel file.")
        # Parallelism falls back to 1 when the configuration does not set it.
        self.parallel = pyfunc_conf.get(PARALLEL, 1)
        session = Meta.init(conn_string).Session()

        logger.info("Getting parser")
        self.corpus_parser = self._get_parser(session)
        logger.info("Getting mention extractor")
        self.mention_extractor = self._get_mention_extractor(session)
        logger.info("Getting candidate extractor")
        self.candidate_extractor = self._get_candidate_extractor(session)
        candidate_classes = self.candidate_extractor.candidate_classes

        # Any model type other than "discriminative" takes the labeler branch.
        self.model_type = pyfunc_conf.get(MODEL_TYPE, "discriminative")
        if self.model_type == "discriminative":
            self.featurizer = Featurizer(session, candidate_classes)
            # NOTE(review): pickle.load can execute arbitrary code; this
            # assumes the files under model_path are trusted - confirm.
            with open(os.path.join(self.model_path, "feature_keys.pkl"),
                      "rb") as f:
                key_names = pickle.load(f)
            # Drop, then upsert, the feature keys loaded from the model dir.
            self.featurizer.drop_keys(key_names)
            self.featurizer.upsert_keys(key_names)

            disc_model = LogisticRegression()

            # Workaround to https://github.com/HazyResearch/fonduer/issues/208
            # The checkpoint's settings and cardinality are applied by hand
            # and the model is built before load() is called.
            checkpoint = torch.load(
                os.path.join(self.model_path, "best_model.pt"))
            disc_model.settings = checkpoint["config"]
            disc_model.cardinality = checkpoint["cardinality"]
            disc_model._build_model()

            disc_model.load(model_file="best_model.pt",
                            save_dir=self.model_path)
            self.disc_model = disc_model
        else:
            self.labeler = Labeler(session, candidate_classes)
            # NOTE(review): same pickle trust assumption as above - confirm.
            with open(os.path.join(self.model_path, "labeler_keys.pkl"),
                      "rb") as f:
                key_names = pickle.load(f)
            # Drop, then upsert, the labeler keys loaded from the model dir.
            self.labeler.drop_keys(key_names)
            self.labeler.upsert_keys(key_names)

            # One label model per candidate class, loaded from
            # "<class name>.pkl" under the model directory.
            self.gen_models = [
                LabelModel.load(
                    os.path.join(self.model_path, _.__name__ + ".pkl"))
                for _ in candidate_classes
            ]
示例#10
0
def test_parse_error_doc_skipping():
    """Check that a faulty HTML document is attempted but not stored.

    The document name must be recorded in ``last_docs`` while
    ``get_last_documents()`` returns an empty list.
    """
    broken_html = "tests/data/html_faulty/ext_diseases_missing_table_tag.html"
    html_docs = HTMLDocPreprocessor(broken_html)
    db_session = Meta.init(CONN_STRING).Session()

    corpus_parser = Parser(db_session)
    corpus_parser.apply(html_docs)

    # Names of every document apply() was called on.
    attempted = corpus_parser.last_docs
    assert attempted == {"ext_diseases_missing_table_tag"}

    # Only documents that were parsed successfully.
    parsed = corpus_parser.get_last_documents()
    assert parsed == []
示例#11
0
def test_parse_structure(caplog):
    """Unit test of OmniParserUDF.parse_structure().

    Only the structural parse of the "md" document is exercised: the first
    phrase's structural attributes, the text of the last phrase, and the
    total number of phrases.
    """
    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    max_docs = 1
    html_file = "tests/data/html_simple/md.html"
    pdf_file = "tests/data/pdf_simple/md.pdf"

    # Take the first (document, text) pair produced by the preprocessor.
    html_docs = HTMLPreprocessor(html_file, max_docs=max_docs)
    doc, text = next(html_docs.generate())
    logger.info("    Text: {}".format(text))

    # The UDF takes its configuration positionally.
    dash_replacements = [(u'[\u2010\u2011\u2012\u2013\u2014\u2212\uf02d]', '-')]
    udf = OmniParserUDF(
        True,  # structural
        ["style"],  # blacklist
        ["span", "br"],  # flatten
        "",  # flatten delim
        True,  # lingual
        True,  # strip
        dash_replacements,  # replace
        True,  # tabular
        True,  # visual
        pdf_file,  # pdf path
        Spacy(),  # lingual parser
    )

    # Materialize everything the structural parse yields.
    parsed = list(udf.parse_structure(doc, text))

    logger.warning("Doc: {}".format(doc))
    for item in parsed:
        logger.warning("    Phrase: {}".format(item.text))

    # Structural attributes of the first phrase.
    first = parsed[0]
    assert first.xpath == "/html/body/h1"
    assert first.html_tag == "h1"
    assert first.html_attrs == ["id=sample-markdown"]

    # Unicode parse of delta in the last phrase.
    last = parsed[-1]
    assert last.text == "δ13Corg"

    # Number of phrases expected in the "md" document.
    assert len(parsed) == 45
示例#12
0
def parse(html_location, database, parallelism=1):
    """
    Run the Fonduer parser over a set of HTML files
    :param html_location: HTML files generated by ``parse_preprocess.py``
    :param database: DB connection string
    :param parallelism: Number of cores to use
    """
    meta = Meta.init(database)
    db_session = meta.Session()
    html_docs = HTMLDocPreprocessor(html_location)

    # Structural and lingual parsing are both enabled.
    parser_options = dict(structural=True, lingual=True, parallelism=parallelism)
    Parser(db_session, **parser_options).apply(html_docs)
示例#13
0
def test_parse_style(caplog):
    """Check that style and other HTML attributes survive parsing.

    Parses the extended diseases document and compares ``html_attrs`` of
    three sentences, selected by position, with their expected values.
    """
    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    # SpaCy on mac has issues with parallel parsing; os.name is "posix" there
    # (and on other POSIX systems), so those run single-core.
    PARALLEL = 1 if os.name == "posix" else 2  # Travis only gives 2 cores

    max_docs = 1
    html_file = "tests/data/html_extended/ext_diseases.html"
    pdf_file = "tests/data/pdf_extended/ext_diseases.pdf"

    html_docs = HTMLDocPreprocessor(html_file, max_docs=max_docs)

    # Parse the document with structural, lingual and visual features.
    corpus_parser = Parser(
        structural=True, lingual=True, visual=True, pdf_path=pdf_file
    )
    corpus_parser.apply(html_docs, parallelism=PARALLEL)

    # First document by name, and all sentences ordered by position.
    doc = session.query(Document).order_by(Document.name).all()[0]
    sentences = list(session.query(Sentence).order_by(Sentence.position).all())

    logger.warning("Doc: {}".format(doc))
    for idx, parsed_sentence in enumerate(sentences):
        logger.warning("    Sentence[{}]: {}".format(idx, parsed_sentence.html_attrs))

    # Sentences under test: position in `sentences` and expected attributes.
    sub_sentences = [
        {
            "index": 6,
            "attr": [
                "class=col-header",
                "hobbies=work:hard;play:harder",
                "type=phenotype",
                "style=background: #f1f1f1; color: aquamarine; font-size: 18px;",
            ],
        },
        {"index": 9, "attr": ["class=row-header", "style=background: #f1f1f1;"]},
        {"index": 11, "attr": ["class=cell", "style=text-align: center;"]},
    ]

    # Assertions
    for case in sub_sentences:
        assert sentences[case["index"]].html_attrs == case["attr"]
示例#14
0
def test_parse_document_md(caplog):
    """Unit test of OmniParser on the "md" document.

    Covers the structural, visual and lingual attributes of the first phrase.
    Parsing runs with a parallelism of 1, so this also serves as a test of
    single-threaded parsing.
    """
    logger = logging.getLogger(__name__)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    PARALLEL = 1
    max_docs = 2
    html_dir = "tests/data/html_simple/"
    pdf_dir = "tests/data/pdf_simple/"

    html_docs = HTMLPreprocessor(html_dir, max_docs=max_docs)

    # Parse with structural, lingual and visual features enabled.
    corpus_parser = OmniParser(
        structural=True, lingual=True, visual=True, pdf_path=pdf_dir
    )
    corpus_parser.apply(html_docs, parallelism=PARALLEL)

    # The md document is the second one when ordered by name.
    doc = session.query(Document).order_by(Document.name).all()[1]

    logger.info("Doc: {}".format(doc))
    for item in doc.phrases:
        logger.info("    Phrase: {}".format(item.text))

    first = doc.phrases[0]

    # Structural attributes
    assert first.xpath == "/html/body/h1"
    assert first.html_tag == "h1"
    assert first.html_attrs == ["id=sample-markdown"]

    # Visual attributes
    assert first.page == [1, 1]
    assert first.top == [35, 35]
    assert first.bottom == [61, 61]
    assert first.right == [111, 231]
    assert first.left == [35, 117]

    # Lingual attributes
    assert first.ner_tags == ["O", "O"]
    assert first.dep_labels == ["compound", "ROOT"]

    # 45 phrases expected in the "md" document.
    assert len(doc.phrases) == 45
示例#15
0
def test_visualizer(caplog):
    """Unit test of visualizer using the md document.

    Parses the md document, extracts organization mentions and candidates
    from it, and asks the visualizer to display the first candidate.
    """
    # Imported inside the test; the docstring must stay the first statement
    # of the function to be picked up as its __doc__.
    from fonduer.utils.visualizer import Visualizer  # noqa

    caplog.set_level(logging.INFO)
    session = Meta.init("postgresql://localhost:5432/" + ATTRIBUTE).Session()

    PARALLEL = 1
    max_docs = 1
    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/md.pdf"

    # Preprocessor for the Docs
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    # Create a Parser and parse the md document
    corpus_parser = Parser(session,
                           structural=True,
                           lingual=True,
                           visual=True,
                           pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    # Grab the md document
    doc = session.query(Document).order_by(Document.name).all()[0]
    assert doc.name == "md"

    # Extract single-word organization mentions.
    organization_ngrams = MentionNgrams(n_max=1)

    Org = mention_subclass("Org")

    organization_matcher = OrganizationMatcher()

    mention_extractor = MentionExtractor(session, [Org], [organization_ngrams],
                                         [organization_matcher])

    mention_extractor.apply([doc], parallelism=PARALLEL)

    # Build candidates from the extracted mentions and store them in split 0.
    Organization = candidate_subclass("Organization", [Org])

    candidate_extractor = CandidateExtractor(session, [Organization])

    candidate_extractor.apply([doc], split=0, parallelism=PARALLEL)

    cands = session.query(Organization).filter(Organization.split == 0).all()

    # Test visualizer; it is given the PDF directory rather than the PDF file
    # that was passed to the parser.
    pdf_path = "tests/data/pdf_simple"
    vis = Visualizer(pdf_path)
    vis.display_candidates([cands[0]])
示例#16
0
def test_parse_document_diseases(caplog):
    """Unit test of OmniParser on the "diseases" document.

    Covers the structural, visual and lingual attributes of one phrase and
    the total number of phrases in the document.
    """
    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    PARALLEL = 2
    max_docs = 2
    html_dir = "tests/data/html_simple/"
    pdf_dir = "tests/data/pdf_simple/"

    html_docs = HTMLPreprocessor(html_dir, max_docs=max_docs)

    # Parse with structural, lingual and visual features enabled.
    corpus_parser = OmniParser(
        structural=True, lingual=True, visual=True, pdf_path=pdf_dir
    )
    corpus_parser.apply(html_docs, parallelism=PARALLEL)

    # The diseases document is the first one when ordered by name.
    doc = session.query(Document).order_by(Document.name).all()[0]

    logger.info("Doc: {}".format(doc))
    for item in doc.phrases:
        logger.info("    Phrase: {}".format(item.text))

    # Phrase under test: index 11 of the sorted phrases.
    target = sorted(doc.phrases)[11]
    logger.info("  {}".format(target))

    # Structural attributes
    assert target.xpath == "/html/body/table[1]/tbody/tr[3]/td[1]/p"
    assert target.html_tag == "p"
    assert target.html_attrs == ["class=s6", "style=padding-top: 1pt"]

    # Visual attributes
    assert target.page == [1, 1, 1]
    assert target.top == [342, 296, 356]
    assert target.left == [318, 369, 318]

    # Lingual attributes
    assert target.ner_tags == ["O", "O", "GPE"]
    assert target.dep_labels == ["ROOT", "prep", "pobj"]

    # 36 phrases expected in the "diseases" document.
    assert len(doc.phrases) == 36
示例#17
0
def test_simple_tokenizer(caplog):
    """Unit test of Parser on a single document with lingual features off.

    With ``lingual=False`` the lingual attributes of the first sentence are
    expected to hold empty placeholders.
    """
    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    # SpaCy on mac has issues with parallel parsing; os.name is "posix" there
    # (and on other POSIX systems), so those run single-core.
    PARALLEL = 1 if os.name == "posix" else 2  # Travis only gives 2 cores

    max_docs = 2
    html_dir = "tests/data/html_simple/"
    pdf_dir = "tests/data/pdf_simple/"

    html_docs = HTMLDocPreprocessor(html_dir, max_docs=max_docs)

    corpus_parser = Parser(
        structural=True, lingual=False, visual=True, pdf_path=pdf_dir
    )
    corpus_parser.apply(html_docs, parallelism=PARALLEL)

    # Second document when ordered by name.
    doc = session.query(Document).order_by(Document.name).all()[1]

    logger.info("Doc: {}".format(doc))
    for idx, parsed_sentence in enumerate(doc.sentences):
        logger.info("    Sentence[{}]: {}".format(idx, parsed_sentence.text))

    # First sentence by position.
    first = sorted(doc.sentences, key=lambda s: s.position)[0]

    # Structural attributes
    assert first.xpath == "/html/body/h1"
    assert first.html_tag == "h1"
    assert first.html_attrs == ["id=sample-markdown"]

    # Lingual attributes
    assert first.ner_tags == ["", ""]
    assert first.dep_labels == ["", ""]
    assert first.dep_parents == [0, 0]
    assert first.lemmas == ["", ""]
    assert first.pos_tags == ["", ""]

    assert len(doc.sentences) == 44
示例#18
0
def test_meta_connection_strings(database_session):
    """Sanity-check how Meta.init handles postgres connection strings.

    Each bad connection string must raise its expected exception type; a
    valid one must set ``Meta.DBNAME`` to ``DB``.
    """
    # (connection string, exception expected from init/Session)
    failing_cases = [
        ("postgresql" + DB, ValueError),
        ("sqlite://somethingsilly" + DB, ValueError),
        ("postgresql://somethingsilly:5432/", OperationalError),
    ]
    for conn_string, expected_error in failing_cases:
        with pytest.raises(expected_error):
            Meta.init(conn_string).Session()

    valid_session = Meta.init("postgresql://localhost:5432/" + DB).Session()
    bound_engine = valid_session.get_bind()

    # Close the session and dispose of its engine before the final check.
    valid_session.close()
    bound_engine.dispose()

    assert Meta.DBNAME == DB
示例#19
0
def test_ngrams(caplog):
    """Test ngram limits in mention extraction.

    Person mentions are extracted twice: first with up to three words, then
    with two to three words so that unigrams are excluded.
    """
    caplog.set_level(logging.INFO)

    PARALLEL = 4
    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()
    docs_path = "tests/data/pure_html/lincoln_short.html"

    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 503
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    Person = mention_subclass("Person")
    up_to_trigrams = MentionNgrams(n_max=3)
    person_matcher = PersonMatcher()

    extractor = MentionExtractor(
        session, [Person], [up_to_trigrams], [person_matcher]
    )
    extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Person).count() == 118
    word_counts = [
        found.context.get_num_words() for found in session.query(Person).all()
    ]
    assert word_counts.count(1) == 49
    assert not any(count > 3 for count in word_counts)

    # Test for unigram exclusion
    bi_and_trigrams = MentionNgrams(n_min=2, n_max=3)
    extractor = MentionExtractor(
        session, [Person], [bi_and_trigrams], [person_matcher]
    )
    extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Person).count() == 69
    word_counts = [
        found.context.get_num_words() for found in session.query(Person).all()
    ]
    assert word_counts.count(1) == 0
    assert not any(count > 3 for count in word_counts)
示例#20
0
def test_preprocessor_parse_file_called_once(mocker):
    """Check that parser.apply calls DocPreprocessor._parse_file exactly once."""
    max_docs = 1
    db_session = Meta.init(CONN_STRING).Session()
    html_dir = "tests/data/html/"

    # Spy on the preprocessor's _parse_file so its calls can be counted.
    html_docs = HTMLDocPreprocessor(html_dir, max_docs=max_docs)
    parse_file_spy = mocker.spy(html_docs, "_parse_file")
    corpus_parser = Parser(db_session)

    # Nothing has been parsed yet.
    parsed_before = corpus_parser.get_last_documents()
    assert len(parsed_before) == 0

    # Parsing
    corpus_parser.apply(html_docs)

    # The last-documents list now holds the parsed documents.
    parsed_after = corpus_parser.get_last_documents()
    assert len(parsed_after) == max_docs

    # doc_preprocessor._parse_file should be called only once (#434).
    parse_file_spy.assert_called_once()
示例#21
0
def test_row_col_ngram_extraction(caplog):
    """Check row/column ngrams for mentions inside and outside of tables.

    The checks live in the matcher function, so they run for every span the
    mention extractor evaluates.
    """
    caplog.set_level(logging.INFO)
    PARALLEL = 1
    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()
    docs_path = "tests/data/pure_html/lincoln_short.html"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    place_ngrams = MentionNgramsTemp(n_max=4)
    Place = mention_subclass("Place")

    def get_row_and_column_ngrams(mention):
        """Assert on the mention's row/column ngrams; match "birth_place" rows."""
        row_ngrams = list(get_row_ngrams(mention))
        col_ngrams = list(get_col_ngrams(mention))
        if mention.sentence.is_tabular():
            # Inside a table no ngram may be None.
            assert not any(x is None for x in row_ngrams)
            assert not any(x is None for x in col_ngrams)
        else:
            # Outside a table both lists hold a single None.
            assert len(row_ngrams) == 1 and row_ngrams[0] is None
            assert len(col_ngrams) == 1 and col_ngrams[0] is None
        return "birth_place" in row_ngrams

    birthplace_matcher = LambdaFunctionMatcher(func=get_row_and_column_ngrams)
    extractor = MentionExtractor(
        session, [Place], [place_ngrams], [birthplace_matcher]
    )
    extractor.apply(docs, parallelism=PARALLEL)
示例#22
0
def test_simple_tokenizer(caplog):
    """Unit test of OmniParser on a single document with lingual features off.

    With ``lingual=False`` the lingual attributes of the first phrase are
    expected to hold empty placeholders.
    """
    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    PARALLEL = 2
    max_docs = 2
    html_dir = "tests/data/html_simple/"
    pdf_dir = "tests/data/pdf_simple/"

    html_docs = HTMLPreprocessor(html_dir, max_docs=max_docs)

    corpus_parser = OmniParser(
        structural=True,
        lingual=False,
        visual=True,
        pdf_path=pdf_dir,
    )
    corpus_parser.apply(html_docs, parallelism=PARALLEL)

    # Second document when ordered by name.
    doc = session.query(Document).order_by(Document.name).all()[1]

    logger.info("Doc: {}".format(doc))
    for idx, item in enumerate(doc.phrases):
        logger.info("    Phrase[{}]: {}".format(idx, item.text))

    first = doc.phrases[0]

    # Structural attributes
    assert first.xpath == "/html/body/h1"
    assert first.html_tag == "h1"
    assert first.html_attrs == ["id=sample-markdown"]

    # Lingual attributes
    assert first.ner_tags == ["", ""]
    assert first.dep_labels == ["", ""]
    assert first.dep_parents == [0, 0]
    assert first.lemmas == ["", ""]
    assert first.pos_tags == ["", ""]

    assert len(doc.phrases) == 44
示例#23
0
def test_warning_on_incorrect_filename(caplog):
    """Check that parsing warns when the PDF path is not a valid PDF.

    The HTML file itself is passed as the PDF path; parsing must emit a
    ``RuntimeWarning`` and still store one document.
    """
    caplog.set_level(logging.INFO)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    PARALLEL = 1
    html_file = "tests/data/html_simple/md_para.html"
    # Deliberately the HTML file again, not a PDF.
    not_a_pdf = "tests/data/html_simple/md_para.html"

    html_docs = HTMLDocPreprocessor(html_file)

    # Parser with every feature enabled, pointed at the bogus PDF path.
    corpus_parser = Parser(
        structural=True,
        tabular=True,
        lingual=True,
        visual=True,
        pdf_path=not_a_pdf,
    )
    with pytest.warns(RuntimeWarning):
        corpus_parser.apply(html_docs, parallelism=PARALLEL)

    assert session.query(Document).count() == 1
示例#24
0
def test_parse_style(caplog):
    """Check that style and other HTML attributes survive parsing.

    Runs OmniParserUDF.parse_structure() on the extended diseases document
    with an empty blacklist and compares ``html_attrs`` of three phrases,
    selected by index, with their expected values.
    """
    caplog.set_level(logging.INFO)
    logger = logging.getLogger(__name__)
    session = Meta.init("postgres://localhost:5432/" + ATTRIBUTE).Session()

    max_docs = 1
    html_file = "tests/data/html_extended/ext_diseases.html"
    pdf_file = "tests/data/pdf_extended/ext_diseases.pdf"

    # Take the first (document, text) pair produced by the preprocessor.
    html_docs = HTMLPreprocessor(html_file, max_docs=max_docs)
    doc, text = next(html_docs.generate())
    logger.info("    Text: {}".format(text))

    # The UDF takes its configuration positionally.
    udf = OmniParserUDF(
        True,  # structural
        [],  # blacklist, empty so that style is not blacklisted
        ["span", "br"],  # flatten
        "",  # flatten delim
        True,  # lingual
        True,  # strip
        [],  # replace
        True,  # tabular
        True,  # visual
        pdf_file,  # pdf path
        Spacy(),  # lingual parser
    )

    # Materialize everything the structural parse yields.
    parsed = list(udf.parse_structure(doc, text))

    logger.warning("Doc: {}".format(doc))
    for item in parsed:
        logger.warning("    Phrase: {}".format(item.html_attrs))

    # Phrases under test: index in `parsed` and expected attributes.
    sub_phrases = [
        {
            "index": 7,
            "attr": [
                "class=col-header",
                "hobbies=work:hard;play:harder",
                "type=phenotype",
                "style=background: #f1f1f1; color: aquamarine; font-size: 18px;",
            ],
        },
        {"index": 10, "attr": ["class=row-header", "style=background: #f1f1f1;"]},
        {"index": 12, "attr": ["class=cell", "style=text-align: center;"]},
    ]

    # Assertions
    for case in sub_phrases:
        assert parsed[case["index"]].html_attrs == case["attr"]
示例#25
0
def test_mention_longest_match(caplog):
    """Test longest match filtering in mention extraction.

    Place mentions are extracted twice with the same matcher function, once
    with ``longest_match_only=False`` and once with ``True``, and the
    resulting spans are compared.
    """
    caplog.set_level(logging.INFO)
    # SpaCy on mac has issue on parallel parsing
    PARALLEL = 1

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()
    docs_path = "tests/data/pure_html/lincoln_short.html"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    name_ngrams = MentionNgramsPart(n_max=3)
    place_ngrams = MentionNgramsTemp(n_max=4)

    Name = mention_subclass("Name")
    Place = mention_subclass("Place")

    def is_birthplace_table_row(mention):
        """Match tabular mentions whose row ngrams include "birth_place"."""
        if not mention.sentence.is_tabular():
            return False
        return "birth_place" in get_row_ngrams(mention, lower=True)

    def extract_place_spans(longest_match_only):
        """Run mention extraction and return the spans of all Place mentions."""
        birthplace_matcher = LambdaFunctionMatcher(
            func=is_birthplace_table_row, longest_match_only=longest_match_only
        )
        extractor = MentionExtractor(
            session,
            [Name, Place],
            [name_ngrams, place_ngrams],
            [PersonMatcher(), birthplace_matcher],
        )
        extractor.apply(docs, parallelism=PARALLEL)
        return [found.context.get_span() for found in session.query(Place).all()]

    # Without longest-match filtering the shorter sub-span is kept too.
    all_spans = extract_place_spans(False)
    assert "Sinking Spring Farm" in all_spans
    assert "Farm" in all_spans
    assert len(all_spans) == 23

    # With longest-match filtering only the longest span remains.
    longest_spans = extract_place_spans(True)
    assert "Sinking Spring Farm" in longest_spans
    assert "Farm" not in longest_spans
    assert len(longest_spans) == 4
示例#26
0
def test_cand_gen(caplog):
    """Test extracting candidates from mentions from documents.

    Parses a single document, extracts Part/Temp/Volt/Fig mentions and then
    builds PartTemp and PartVolt candidates three ways: without throttlers,
    with a throttler list containing ``None``, and with one throttler per
    candidate class. Along the way it checks that the extractor constructors
    raise ``ValueError`` for argument lists of mismatched length and that
    deleting candidates/mentions affects the other table as asserted below.

    Args:
        caplog: pytest log-capture fixture; its level is set to INFO.
    """
    caplog.set_level(logging.INFO)

    if platform == "darwin":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        logger.info("Using two cores.")
        PARALLEL = 2  # Travis only gives 2 cores

    def do_nothing_matcher(fig):
        """Accept every figure handed to the figure matcher."""
        return True

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    with pytest.raises(ValueError):
        mention_extractor = MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    with pytest.raises(ValueError):
        # The mention spaces are complete and valid here, so the only thing
        # wrong with this call is the number of matchers.
        mention_extractor = MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 70
    assert session.query(Volt).count() == 33
    assert session.query(Temp).count() == 23
    assert session.query(Fig).count() == 31
    part = session.query(Part).order_by(Part.id).all()[0]
    volt = session.query(Volt).order_by(Volt.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Volt: {volt.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    with pytest.raises(ValueError):
        candidate_extractor = CandidateExtractor(
            session,
            [PartTemp, PartVolt],
            throttlers=[
                temp_throttler,
                volt_throttler,
                volt_throttler,
            ],  # Fail, mismatched arity
        )

    with pytest.raises(ValueError):
        candidate_extractor = CandidateExtractor(
            session,
            [PartTemp],  # Fail, mismatched arity
            throttlers=[temp_throttler, volt_throttler],
        )

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt]
    )  # Pass, no throttler

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1610
    assert session.query(PartVolt).count() == 2310
    assert session.query(Candidate).count() == 3920
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(PartVolt).count() == 0

    # Test with None in throttlers in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, None]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    # Only PartTemp shrinks; PartVolt keeps its unthrottled count.
    assert session.query(PartTemp).count() == 1432
    assert session.query(PartVolt).count() == 2310
    assert session.query(Candidate).count() == 3742
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0

    # Test with one throttler per candidate class
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1432
    assert session.query(PartVolt).count() == 1993
    assert session.query(Candidate).count() == 3425
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].volts) == 33
    assert len(docs[0].temps) == 23

    # Test that deletion of a Candidate does not delete the Mention
    session.query(PartTemp).delete(synchronize_session="fetch")
    assert session.query(PartTemp).count() == 0
    assert session.query(Temp).count() == 23
    assert session.query(Part).count() == 70

    # Test deletion of Candidate if Mention is deleted
    assert session.query(PartVolt).count() == 1993
    assert session.query(Volt).count() == 33
    session.query(Volt).delete(synchronize_session="fetch")
    assert session.query(Volt).count() == 0
    assert session.query(PartVolt).count() == 0
示例#27
0
def test_cand_gen_cascading_delete(caplog):
    """Verify that candidate and mention deletions cascade as expected.

    Builds Part/Temp mentions and throttled PartTemp candidates for a single
    document, then checks that deleting a row through the parent ``Candidate``
    class also removes it from ``PartTemp``, and that clearing the mentions
    leaves no candidates behind.

    Args:
        caplog: pytest log-capture fixture; its level is set to INFO.
    """
    caplog.set_level(logging.INFO)

    # One worker on darwin, two elsewhere (Travis only gives 2 cores).
    on_darwin = platform == "darwin"
    logger.info("Using single core." if on_darwin else "Using two cores.")
    PARALLEL = 1 if on_darwin else 2

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    # Parsing
    logger.info("Parsing...")
    html_docs = HTMLDocPreprocessor("tests/data/html/", max_docs=max_docs)
    parser = Parser(
        session,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path="tests/data/pdf/",
    )
    parser.apply(html_docs, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_space = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_space = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp],
        [part_space, temp_space],
        [part_matcher, temp_matcher],
    )
    mention_extractor.clear_all()
    mention_extractor.apply(docs, parallelism=PARALLEL)

    for mention_cls, expected in ((Mention, 93), (Part, 70), (Temp, 23)):
        assert session.query(mention_cls).count() == expected
    part, temp = (
        session.query(cls).order_by(cls.id).all()[0] for cls in (Part, Temp)
    )
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    for cand_cls in (PartTemp, Candidate):
        assert session.query(cand_cls).count() == 1432
    first_doc = docs[0]
    assert first_doc.name == "112823"
    assert len(first_doc.parts) == 70
    assert len(first_doc.temps) == 23

    # Delete from parent class should cascade to child
    doomed = session.query(Candidate).first()
    session.query(Candidate).filter_by(id=doomed.id).delete(
        synchronize_session="fetch"
    )
    for cand_cls in (Candidate, PartTemp):
        assert session.query(cand_cls).count() == 1431

    # Clearing Mentions should also delete Candidates
    mention_extractor.clear()
    for table in (Mention, Part, Temp, PartTemp, Candidate):
        assert session.query(table).count() == 0
示例#28
0
def test_cand_gen_cascading_delete():
    """Exercise delete cascades between mentions, candidates and subclasses.

    After extracting Part/Temp mentions and throttled PartTemp candidates
    from one document, the test deletes a row via ``Candidate``, deletes a
    row via ``PartTemp`` (mentions must be untouched), and finally clears the
    mentions, which must leave no candidates.
    """
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    def n_rows(model):
        """Return the current number of rows stored for ``model``."""
        return session.query(model).count()

    # Parsing
    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor("tests/data/html/", max_docs=max_docs)
    parser_kwargs = dict(
        structural=True, lingual=True, visual=True, pdf_path="tests/data/pdf/"
    )
    corpus_parser = Parser(session, **parser_kwargs)
    corpus_parser.apply(preprocessor, parallelism=PARALLEL)
    assert n_rows(Document) == max_docs
    assert n_rows(Sentence) == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    spaces = [
        MentionNgramsPart(parts_by_doc=None, n_max=3),
        MentionNgramsTemp(n_max=2),
    ]
    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], spaces, [part_matcher, temp_matcher]
    )
    mention_extractor.clear_all()
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert n_rows(Mention) == 93
    assert n_rows(Part) == 70
    assert n_rows(Temp) == 23
    part = session.query(Part).order_by(Part.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    extractor = CandidateExtractor(session, [PartTemp], throttlers=[temp_throttler])
    extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert n_rows(PartTemp) == 1432
    assert n_rows(Candidate) == 1432
    document = docs[0]
    assert document.name == "112823"
    assert len(document.parts) == 70
    assert len(document.temps) == 23

    # Delete from parent class should cascade to child
    target = session.query(Candidate).first()
    session.query(Candidate).filter_by(id=target.id).delete(
        synchronize_session="fetch"
    )
    assert n_rows(Candidate) == 1431
    assert n_rows(PartTemp) == 1431

    # Test that deletion of a Candidate does not delete the Mention
    target = session.query(PartTemp).first()
    session.query(PartTemp).filter_by(id=target.id).delete(
        synchronize_session="fetch"
    )
    assert n_rows(PartTemp) == 1430
    assert n_rows(Temp) == 23
    assert n_rows(Part) == 70

    # Clearing Mentions should also delete Candidates
    mention_extractor.clear()
    assert n_rows(Mention) == 0
    assert n_rows(Part) == 0
    assert n_rows(Temp) == 0
    assert n_rows(PartTemp) == 0
    assert n_rows(Candidate) == 0
示例#29
0
def test_feature_extraction():
    """Test featurization with default, customized and single-type libraries.

    The same PartTemp candidates are featurized several times, clearing the
    keys after each run, and the resulting ``FeatureKey`` counts are compared:

    * the default library must produce as many keys as the textual, tabular,
      structural and visual libraries produce together;
    * adding a custom extractor that yields one ``cand_id_<id>`` key per
      candidate must add exactly one key per candidate.
    """
    PARALLEL = 1

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)
    temp_ngrams = MentionNgrams(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert docs[0].name == "112823"
    assert session.query(Part).count() == 58
    assert session.query(Temp).count() == 16
    part = session.query(Part).order_by(Part.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(session, [PartTemp])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    n_cands = session.query(PartTemp).count()

    def count_feature_keys(feature_extractors=None):
        """Featurize split 0, count the FeatureKeys, then clear them again.

        When ``feature_extractors`` is None the Featurizer is built without
        that argument, so it uses whatever default library it ships with.
        """
        if feature_extractors is None:
            featurizer = Featurizer(session, [PartTemp])
        else:
            featurizer = Featurizer(
                session, [PartTemp], feature_extractors=feature_extractors
            )
        featurizer.apply(split=0, train=True, parallelism=PARALLEL)
        n_keys = session.query(FeatureKey).count()
        # Clear so the next run starts from an empty key table.
        featurizer.clear(train=True)
        return n_keys

    # Example feature extractor: one uniquely named feature per candidate
    def feat_ext(candidates):
        candidates = candidates if isinstance(candidates, list) else [candidates]
        for candidate in candidates:
            yield candidate.id, f"cand_id_{candidate.id}", 1

    # Featurization based on default feature library
    n_default_feats = count_feature_keys()

    # Default feature library with one extra feature extractor
    n_default_w_customized_features = count_feature_keys(
        FeatureExtractor(customize_feature_funcs=[feat_ext])
    )

    # Featurization with a single feature type at a time
    n_textual_features = count_feature_keys(FeatureExtractor(features=["textual"]))
    n_tabular_features = count_feature_keys(FeatureExtractor(features=["tabular"]))
    n_structural_features = count_feature_keys(
        FeatureExtractor(features=["structural"])
    )
    n_visual_features = count_feature_keys(FeatureExtractor(features=["visual"]))

    assert n_default_feats == (
        n_textual_features
        + n_tabular_features
        + n_structural_features
        + n_visual_features
    )

    assert n_default_w_customized_features == n_default_feats + n_cands
示例#30
0
def test_incremental(caplog):
    """Run an end-to-end test on incremental additions.

    A first document (dtc114w) is parsed and taken through mention
    extraction, candidate extraction, featurization and labeling. A second
    document (112823) is then added with ``clear=False`` at each stage, and
    the row counts and matrix shapes are asserted to cover both documents.

    Args:
        caplog: pytest log-capture fixture; its level is set to INFO.
    """
    caplog.set_level(logging.INFO)

    PARALLEL = 1

    max_docs = 1

    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    # First document: a single HTML file and its matching PDF.
    docs_path = "tests/data/html/dtc114w.html"
    pdf_path = "tests/data/pdf/dtc114w.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    # NOTE(review): both variables are filled from get_documents(), so the
    # assertion below compares results of the same call; get_last_documents()
    # may have been intended for last_docs — confirm.
    docs = corpus_parser.get_documents()
    last_docs = corpus_parser.get_documents()

    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(session, [Part, Temp],
                                         [part_ngrams, temp_ngrams],
                                         [part_matcher, temp_matcher])

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(session, [PartTemp],
                                             throttlers=[temp_throttler])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 70
    assert session.query(Candidate).count() == 70

    # Grab candidate lists
    # The result is indexed per candidate class below: one entry (PartTemp)
    # holding the 70 candidates of split 0.
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 70

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 70
    assert session.query(FeatureKey).count() == 512

    # The feature matrix is asserted to be (number of candidates, number of keys).
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (70, 512)
    assert len(featurizer.get_keys()) == 512

    # Test Dropping FeatureKey
    # NOTE(review): the key count is asserted to stay at 512 after drop_keys;
    # presumably the named key is not removed for the candidate class in use
    # or is not among the stored keys — confirm against drop_keys semantics.
    featurizer.drop_keys(["CORE_e1_LENGTH_1"])
    assert session.query(FeatureKey).count() == 512

    # Labeling functions applied to PartTemp candidates.
    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    # lfs is a list of lists; presumably one inner list per candidate class
    # given to the Labeler — confirm.
    labeler.apply(split=0,
                  lfs=[stg_temp_lfs],
                  train=True,
                  parallelism=PARALLEL)
    assert session.query(Label).count() == 70

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (70, 5)
    assert len(labeler.get_keys()) == 5

    # Second document, added on top of the first one.
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    # clear=False keeps the previously parsed document, as asserted below.
    corpus_parser.apply(doc_preprocessor, pdf_path=pdf_path, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    # get_last_documents() is asserted to return only the newly added document.
    new_docs = corpus_parser.get_last_documents()

    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # Get mentions from just the new docs
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)

    # Counts now cover both documents (11 and 8 came from the first one).
    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Just run candidate extraction and assign to split 0
    candidate_extractor.apply(new_docs,
                              split=0,
                              parallelism=PARALLEL,
                              clear=False)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1502

    # Update features
    featurizer.update(new_docs, parallelism=PARALLEL)
    assert session.query(Feature).count() == 1502
    assert session.query(FeatureKey).count() == 2573
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (1502, 2573)
    assert len(featurizer.get_keys()) == 2573

    # Update Labels
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 1502
    # One more label key than before the second document was added (5 -> 6).
    assert session.query(LabelKey).count() == 6
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (1502, 6)

    # Test clear
    featurizer.clear(train=True)
    assert session.query(FeatureKey).count() == 0