def test_multimodal_cand(caplog):
    """Test multimodal candidate generation"""
    # Check multimodal mention extraction against a PostgreSQL-backed session.
    #
    # Parses the single ``radiology.html`` fixture with the structural and
    # lingual parser. It then runs a ``MentionExtractor`` with a
    # ``DoNothingMatcher`` for each of the eight multimodal mention spaces and
    # checks how many mentions of each type end up in the session.
    #
    # NOTE(review): this file later defines another function named
    # ``test_multimodal_cand``. Python rebinds the name, so pytest only
    # collects that later definition and this body never runs (pyflakes
    # F811). Rename or remove one of the two.
    caplog.set_level(logging.INFO)
    parallelism = 4
    doc_limit = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    # Parse the HTML fixture into Documents / Sentences.
    html_path = "tests/data/pure_html/radiology.html"
    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor(html_path, max_docs=doc_limit)
    parser = Parser(session, structural=True, lingual=True)
    parser.apply(preprocessor, parallelism=parallelism)

    assert session.query(Document).count() == doc_limit
    assert session.query(Sentence).count() == 35

    documents = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    # (subclass name, mention-space factory), listed in creation order.
    # All subclasses are created first, then all mention spaces.
    specs = (
        ("m_doc", MentionDocuments),
        ("m_sec", MentionSections),
        ("m_tab", MentionTables),
        ("m_fig", MentionFigures),
        ("m_cell", MentionCells),
        ("m_para", MentionParagraphs),
        ("m_cap", MentionCaptions),
        ("m_sent", MentionSentences),
    )
    subclass_of = {name: mention_subclass(name) for name, _ in specs}
    space_of = {name: factory() for name, factory in specs}

    # Order in which the extractor receives the types, paired with the
    # number of mentions expected in the session for each one.
    expected_counts = (
        ("m_doc", 1),
        ("m_cap", 2),
        ("m_sec", 5),
        ("m_tab", 2),
        ("m_fig", 2),
        ("m_para", 30),
        ("m_sent", 35),
        ("m_cell", 21),
    )
    subclasses = [subclass_of[name] for name, _ in expected_counts]
    spaces = [space_of[name] for name, _ in expected_counts]
    # One DoNothingMatcher instance, shared across all eight slots.
    matchers = [DoNothingMatcher()] * 8

    extractor = MentionExtractor(
        session, subclasses, spaces, matchers, parallelism=parallelism
    )
    extractor.apply(documents)

    for name, expected in expected_counts:
        assert session.query(subclass_of[name]).count() == expected
def test_multimodal_cand():
    """Test multimodal candidate generation"""
    # Check multimodal mention extraction on an in-memory parsed document.
    #
    # Parses ``radiology.html`` with ``parse_doc``. It then runs a
    # ``MentionExtractorUDF`` with a ``DoNothingMatcher`` for each of the
    # eight multimodal mention spaces and checks how many mentions of each
    # type are attached to the returned document.
    stem = "radiology"
    html_path = f"tests/data/pure_html/{stem}.html"
    doc = parse_doc(html_path, stem)

    assert len(doc.sentences) == 35

    # Mention Extraction
    # (subclass name, mention-space factory), listed in creation order.
    # All subclasses are created first, then all mention spaces.
    specs = (
        ("m_doc", MentionDocuments),
        ("m_sec", MentionSections),
        ("m_tab", MentionTables),
        ("m_fig", MentionFigures),
        ("m_cell", MentionCells),
        ("m_para", MentionParagraphs),
        ("m_cap", MentionCaptions),
        ("m_sent", MentionSentences),
    )
    subclass_of = {name: mention_subclass(name) for name, _ in specs}
    space_of = {name: factory() for name, factory in specs}

    # Order in which the UDF receives the types. Each entry is paired with
    # the document attribute holding that type's mentions and the expected
    # number of mentions.
    expected = (
        ("m_doc", "m_docs", 1),
        ("m_cap", "m_caps", 2),
        ("m_sec", "m_secs", 5),
        ("m_tab", "m_tabs", 2),
        ("m_fig", "m_figs", 2),
        ("m_para", "m_paras", 30),
        ("m_sent", "m_sents", 35),
        ("m_cell", "m_cells", 21),
    )
    subclasses = [subclass_of[name] for name, _, _ in expected]
    spaces = [space_of[name] for name, _, _ in expected]
    # One DoNothingMatcher instance, shared across all eight slots.
    matchers = [DoNothingMatcher()] * 8

    udf = MentionExtractorUDF(subclasses, spaces, matchers)
    doc = udf.apply(doc)

    for _, attr, count in expected:
        assert len(getattr(doc, attr)) == count