def get_subclasses(experiment):
    """Assemble the Data/Row/Col extraction configuration for one experiment.

    Args:
        experiment: Passed through to ``get_label_matcher`` when the
            label-based half of each matcher is built.

    Returns:
        A 5-tuple ``(mention_classes, mention_spaces, matchers,
        candidate_classes, throttlers)``. The first three lists are aligned
        positionally as Data, Row, Col; the candidate and throttler lists
        are aligned as RowCandidate, ColCandidate.
    """
    # Step 1: one mention subclass per entity kind.
    data_cls = mention_subclass("Data")
    row_cls = mention_subclass("Row")
    col_cls = mention_subclass("Col")

    # Step 2: all three kinds are drawn from whole sentences. Earlier
    # variants used MentionNgrams(n_max=3) for data and
    # MentionNgrams(n_min=1, n_max=8) for rows/columns.
    spaces = [MentionSentences(), MentionSentences(), MentionSentences()]

    # Step 3: a span qualifies only if it matches the regex AND the
    # experiment-specific label lookup accepts it.
    def labelled(pattern, label):
        return Intersect(
            RegexMatchSpan(rgx=pattern, longest_match_only=True),
            LambdaFunctionMatcher(func=get_label_matcher(label, experiment)),
        )

    any_text = r"^.*$"
    matchers = [
        labelled(r"[0-9-,.%$#]+( to | )?[0-9-,.%$#]*|^x$", "Data"),
        # Rows and columns are both validated against the "Header" label.
        labelled(any_text, "Header"),
        labelled(any_text, "Header"),
    ]

    # Step 4: each candidate pairs a data cell with one of its headers.
    row_candidate = candidate_subclass("RowCandidate", [data_cls, row_cls])
    col_candidate = candidate_subclass("ColCandidate", [data_cls, col_cls])

    # Step 5: throttlers line up with the candidate classes above.
    return (
        [data_cls, row_cls, col_cls],
        spaces,
        matchers,
        [row_candidate, col_candidate],
        [row_filter, col_filter],
    )
def test_multimodal_cand(caplog):
    """Test multimodal candidate generation.

    Parses ``tests/data/pure_html/radiology.html`` into the PostgreSQL
    database named by ``DB``. It then runs a ``MentionExtractor`` with one
    mention subclass per structural context (document, caption, section,
    table, figure, paragraph, sentence, cell), using a ``DoNothingMatcher``
    for each, and checks the number of stored mentions of every kind.
    """
    caplog.set_level(logging.INFO)
    PARALLEL = 4
    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()
    # Close the session even when an assertion fails so its database
    # resources are not leaked into later tests.
    try:
        docs_path = "tests/data/pure_html/radiology.html"

        logger.info("Parsing...")
        doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
        corpus_parser = Parser(session, structural=True, lingual=True)
        corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
        assert session.query(Document).count() == max_docs
        assert session.query(Sentence).count() == 35
        docs = session.query(Document).order_by(Document.name).all()

        # Mention Extraction
        ms_doc = mention_subclass("m_doc")
        ms_sec = mention_subclass("m_sec")
        ms_tab = mention_subclass("m_tab")
        ms_fig = mention_subclass("m_fig")
        ms_cell = mention_subclass("m_cell")
        ms_para = mention_subclass("m_para")
        ms_cap = mention_subclass("m_cap")
        ms_sent = mention_subclass("m_sent")

        m_doc = MentionDocuments()
        m_sec = MentionSections()
        m_tab = MentionTables()
        m_fig = MentionFigures()
        m_cell = MentionCells()
        m_para = MentionParagraphs()
        m_cap = MentionCaptions()
        m_sent = MentionSentences()

        # The subclass list and the mention-space list must stay
        # positionally aligned.
        ms = [ms_doc, ms_cap, ms_sec, ms_tab, ms_fig, ms_para, ms_sent, ms_cell]
        m = [m_doc, m_cap, m_sec, m_tab, m_fig, m_para, m_sent, m_cell]
        # One shared matcher instance repeated for all eight spaces.
        matchers = [DoNothingMatcher()] * 8

        mention_extractor = MentionExtractor(
            session, ms, m, matchers, parallelism=PARALLEL
        )
        mention_extractor.apply(docs)

        assert session.query(ms_doc).count() == 1
        assert session.query(ms_cap).count() == 2
        assert session.query(ms_sec).count() == 5
        assert session.query(ms_tab).count() == 2
        assert session.query(ms_fig).count() == 2
        assert session.query(ms_para).count() == 30
        assert session.query(ms_sent).count() == 35
        assert session.query(ms_cell).count() == 21
    finally:
        session.close()
def test_multimodal_cand():
    """Check mention counts for each structural context of radiology.html.

    The document is parsed in memory via ``parse_doc`` and then passed
    through a ``MentionExtractorUDF``. The extractor is configured with one
    mention subclass and mention space per context type, all sharing a
    single ``DoNothingMatcher`` instance.
    """
    stem = "radiology"
    doc = parse_doc(f"tests/data/pure_html/{stem}.html", stem)
    assert len(doc.sentences) == 35

    # Create the subclasses and spaces in a fixed order; the dicts keep
    # insertion order, and the values are evaluated left to right.
    subclass_names = [
        "m_doc",
        "m_sec",
        "m_tab",
        "m_fig",
        "m_cell",
        "m_para",
        "m_cap",
        "m_sent",
    ]
    subclass_by_name = {name: mention_subclass(name) for name in subclass_names}
    space_by_name = {
        "m_doc": MentionDocuments(),
        "m_sec": MentionSections(),
        "m_tab": MentionTables(),
        "m_fig": MentionFigures(),
        "m_cell": MentionCells(),
        "m_para": MentionParagraphs(),
        "m_cap": MentionCaptions(),
        "m_sent": MentionSentences(),
    }

    # Extraction order paired with the expected number of mentions.
    expected = [
        ("m_doc", 1),
        ("m_cap", 2),
        ("m_sec", 5),
        ("m_tab", 2),
        ("m_fig", 2),
        ("m_para", 30),
        ("m_sent", 35),
        ("m_cell", 21),
    ]
    order = [name for name, _ in expected]

    extractor = MentionExtractorUDF(
        [subclass_by_name[name] for name in order],
        [space_by_name[name] for name in order],
        [DoNothingMatcher()] * len(order),
    )
    doc = extractor.apply(doc)

    # Extracted mentions are exposed on the document as "<name>s".
    for name, count in expected:
        assert len(getattr(doc, name + "s")) == count, name
from fonduer.candidates import MentionSentences

# Shared sentence-level mention space, created once when the module loads.
mention_sentence = MentionSentences()


def get_mention_spaces():
    """Return the mention spaces used for extraction.

    Returns:
        A new one-element list whose single entry is the module-level
        ``mention_sentence`` instance (the same object on every call).
    """
    spaces = [mention_sentence]
    return spaces