def test_ngrams(): """Test ngram limits in mention extraction""" file_name = "lincoln_short" docs_path = f"tests/data/pure_html/{file_name}.html" doc = parse_doc(docs_path, file_name) # Mention Extraction Person = mention_subclass("Person") person_ngrams = MentionNgrams(n_max=3) person_matcher = PersonMatcher() mention_extractor_udf = MentionExtractorUDF( [Person], [person_ngrams], [person_matcher] ) doc = mention_extractor_udf.apply(doc) assert len(doc.persons) == 118 mentions = doc.persons assert len([x for x in mentions if x.context.get_num_words() == 1]) == 49 assert len([x for x in mentions if x.context.get_num_words() > 3]) == 0 # Test for unigram exclusion for mention in doc.persons[:]: doc.persons.remove(mention) assert len(doc.persons) == 0 person_ngrams = MentionNgrams(n_min=2, n_max=3) mention_extractor_udf = MentionExtractorUDF( [Person], [person_ngrams], [person_matcher] ) doc = mention_extractor_udf.apply(doc) assert len(doc.persons) == 69 mentions = doc.persons assert len([x for x in mentions if x.context.get_num_words() == 1]) == 0 assert len([x for x in mentions if x.context.get_num_words() > 3]) == 0
def __init__(
    self,
    n_min: int = 1,
    n_max: int = 5,
    split_tokens: Collection[str] = [],
    types: Optional[Collection[str]] = None,
) -> None:
    """Initialize MentionNgrams.

    :param n_min: lower n-gram bound, forwarded to ``MentionNgrams``.
    :param n_max: upper n-gram bound, forwarded to ``MentionNgrams``.
    :param split_tokens: tokens to split spans on, forwarded to
        ``MentionNgrams``. (Mutable default is only forwarded, never
        mutated here.)
    :param types: optional iterable of type names; each entry is stripped
        and lowercased before being stored on ``self.types``.
    """
    MentionNgrams.__init__(self, n_min=n_min, n_max=n_max, split_tokens=split_tokens)
    if types is not None:
        # Normalize once so later comparisons can be case-insensitive.
        self.types = [t.strip().lower() for t in types]
    else:
        self.types = None
def __init__(self, parts_by_doc=None, n_max=3, expand=True, split_tokens=None):
    """MentionNgrams specifically for transistor parts.

    :param parts_by_doc: a dictionary d where d[document_name.upper()] =
        [partA, partB, ...]
    :param n_max: upper n-gram bound, forwarded to ``MentionNgrams``.
    :param expand: if True, expand spans with ``expand_part_range``;
        otherwise each span maps to itself.
    :param split_tokens: tokens to split spans on, forwarded to
        ``MentionNgrams``.
    """
    # Bug fix: a hard-coded `split_tokens=None` was previously forwarded,
    # silently discarding the caller-supplied `split_tokens` argument.
    MentionNgrams.__init__(self, n_max=n_max, split_tokens=split_tokens)
    self.parts_by_doc = parts_by_doc
    # When expansion is disabled, use an identity-like expander so callers
    # can iterate the result uniformly.
    self.expander = expand_part_range if expand else (lambda x: [x])
def test_visualizer():
    """Unit test of visualizer using the md document."""
    from fonduer.utils.visualizer import Visualizer  # noqa

    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/md.pdf"

    # Parse the md document (with visual info from the PDF).
    doc = parse_doc(docs_path, "md", pdf_path)
    assert doc.name == "md"

    # Extract organization mentions (unigrams only).
    Org = mention_subclass("Org")
    extractor = MentionExtractorUDF(
        [Org], [MentionNgrams(n_max=1)], [OrganizationMatcher()]
    )
    doc = extractor.apply(doc)

    # Build candidates from those mentions.
    Organization = candidate_subclass("Organization", [Org])
    cand_extractor = CandidateExtractorUDF([Organization], None, False, False, True)
    doc = cand_extractor.apply(doc, split=0)
    cands = doc.organizations

    # Render the first candidate through the visualizer.
    vis = Visualizer("tests/data/pdf_simple")
    vis.display_candidates([cands[0]])
def apply(self, doc):
    """Yield n-grams from *doc*, rewriting spans that end in ".0".

    Spans whose text ends with ".0" (e.g. "5.0") are re-emitted as
    implicit spans with that suffix stripped ("5"); all other spans pass
    through unchanged.
    """
    for ts in MentionNgrams.apply(self, doc):
        if ts.get_span().endswith(".0"):
            # Drop the trailing ".0" and emit an implicit span that
            # carries the linguistic attributes of the span's last token.
            value = ts.get_span()[:-2]
            yield TemporaryImplicitSpanMention(
                sentence=ts.sentence,
                char_start=ts.char_start,
                char_end=ts.char_end,
                expander_key="volt_expander",
                position=0,
                text=value,
                words=[value],
                lemmas=[value],
                pos_tags=[ts.get_attrib_tokens("pos_tags")[-1]],
                ner_tags=[ts.get_attrib_tokens("ner_tags")[-1]],
                dep_parents=[ts.get_attrib_tokens("dep_parents")[-1]],
                dep_labels=[ts.get_attrib_tokens("dep_labels")[-1]],
                # Visual coordinates only exist when the sentence has
                # visual information; otherwise use [None] placeholders.
                page=[ts.get_attrib_tokens("page")[-1]]
                if ts.sentence.is_visual()
                else [None],
                top=[ts.get_attrib_tokens("top")[-1]]
                if ts.sentence.is_visual()
                else [None],
                left=[ts.get_attrib_tokens("left")[-1]]
                if ts.sentence.is_visual()
                else [None],
                bottom=[ts.get_attrib_tokens("bottom")[-1]]
                if ts.sentence.is_visual()
                else [None],
                right=[ts.get_attrib_tokens("right")[-1]]
                if ts.sentence.is_visual()
                else [None],
                meta=None,
            )
        else:
            # Non-".0" spans are forwarded untouched.
            yield ts
def test_ngrams(caplog):
    """Test ngram limits in mention extraction (database-backed variant).

    Parses one document into Postgres, then verifies that MentionNgrams
    honors n_max=3 and, on a second pass, that n_min=2 excludes unigrams.
    """
    caplog.set_level(logging.INFO)
    PARALLEL = 4
    max_docs = 1
    # Requires a running local Postgres instance named by DB.
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()
    docs_path = "tests/data/pure_html/lincoln_short.html"
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 503
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    Person = mention_subclass("Person")
    person_ngrams = MentionNgrams(n_max=3)
    person_matcher = PersonMatcher()
    mention_extractor = MentionExtractor(
        session, [Person], [person_ngrams], [person_matcher]
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)
    assert session.query(Person).count() == 118
    mentions = session.query(Person).all()
    # 49 of the 118 mentions are single words; none exceed n_max=3.
    assert len([x for x in mentions if x.context.get_num_words() == 1]) == 49
    assert len([x for x in mentions if x.context.get_num_words() > 3]) == 0

    # Test for unigram exclusion
    person_ngrams = MentionNgrams(n_min=2, n_max=3)
    mention_extractor = MentionExtractor(
        session, [Person], [person_ngrams], [person_matcher]
    )
    # Re-applying replaces the previous extraction for this mention class.
    mention_extractor.apply(docs, parallelism=PARALLEL)
    assert session.query(Person).count() == 69
    mentions = session.query(Person).all()
    # With n_min=2 there must be no single-word mentions left.
    assert len([x for x in mentions if x.context.get_num_words() == 1]) == 0
    assert len([x for x in mentions if x.context.get_num_words() > 3]) == 0
def apply(self, session, context):
    """Yield n-grams, normalizing optionally-signed integer spans.

    A span matching "<sign?><spaces?><digits>" is re-emitted as an
    implicit span with the sign collapsed onto the digits (e.g. "- 150"
    becomes "-150"). Plus signs are dropped; a spaced "+ 150" bigram is
    skipped entirely because the bare "150" unigram already covers it.
    All other spans pass through unchanged.
    """
    for ts in MentionNgrams.apply(self, session, context):
        # Sign class covers ASCII +/- plus several Unicode hyphen/dash
        # and minus code points seen in PDF extractions.
        m = re.match(
            r"^([\+\-\u2010\u2011\u2012\u2013\u2014\u2212\uf02d])?(\s*)(\d+)$",
            ts.get_span(),
            re.U,
        )
        if m:
            if m.group(1) is None:
                temp = ""
            elif m.group(1) == "+":
                if m.group(2) != "":
                    # If bigram '+ 150' is seen, accept the unigram '150',
                    # not both
                    continue
                temp = ""
            else:
                # m.group(1) is a type of negative sign
                # A bigram '- 150' is different from unigram '150', so we
                # keep the implicit '-150'
                temp = "-"
            temp += m.group(3)
            yield TemporaryImplicitSpan(
                sentence=ts.sentence,
                char_start=ts.char_start,
                char_end=ts.char_end,
                expander_key=u"temp_expander",
                position=0,
                text=temp,
                words=[temp],
                lemmas=[temp],
                pos_tags=[ts.get_attrib_tokens("pos_tags")[-1]],
                ner_tags=[ts.get_attrib_tokens("ner_tags")[-1]],
                dep_parents=[ts.get_attrib_tokens("dep_parents")[-1]],
                dep_labels=[ts.get_attrib_tokens("dep_labels")[-1]],
                # Visual coordinates only exist when the sentence has
                # visual information; otherwise use [None] placeholders.
                page=[ts.get_attrib_tokens("page")[-1]]
                if ts.sentence.is_visual()
                else [None],
                top=[ts.get_attrib_tokens("top")[-1]]
                if ts.sentence.is_visual()
                else [None],
                left=[ts.get_attrib_tokens("left")[-1]]
                if ts.sentence.is_visual()
                else [None],
                bottom=[ts.get_attrib_tokens("bottom")[-1]]
                if ts.sentence.is_visual()
                else [None],
                right=[ts.get_attrib_tokens("right")[-1]]
                if ts.sentence.is_visual()
                else [None],
                meta=None,
            )
        else:
            # Non-numeric spans are forwarded untouched.
            yield ts
def test_visualizer(caplog): from fonduer.utils.visualizer import Visualizer # noqa """Unit test of visualizer using the md document. """ caplog.set_level(logging.INFO) session = Meta.init("postgresql://localhost:5432/" + ATTRIBUTE).Session() PARALLEL = 1 max_docs = 1 docs_path = "tests/data/html_simple/md.html" pdf_path = "tests/data/pdf_simple/md.pdf" # Preprocessor for the Docs doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs) # Create an Parser and parse the md document corpus_parser = Parser(session, structural=True, lingual=True, visual=True, pdf_path=pdf_path) corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL) # Grab the md document doc = session.query(Document).order_by(Document.name).all()[0] assert doc.name == "md" organization_ngrams = MentionNgrams(n_max=1) Org = mention_subclass("Org") organization_matcher = OrganizationMatcher() mention_extractor = MentionExtractor(session, [Org], [organization_ngrams], [organization_matcher]) mention_extractor.apply([doc], parallelism=PARALLEL) Organization = candidate_subclass("Organization", [Org]) candidate_extractor = CandidateExtractor(session, [Organization]) candidate_extractor.apply([doc], split=0, parallelism=PARALLEL) cands = session.query(Organization).filter(Organization.split == 0).all() # Test visualizer pdf_path = "tests/data/pdf_simple" vis = Visualizer(pdf_path) vis.display_candidates([cands[0]])
def mention_setup():
    """Parse the md document and return all of its 1-gram span mentions."""
    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/"

    # Read the single HTML document.
    preprocessor = HTMLDocPreprocessor(docs_path)
    doc = next(iter(preprocessor))

    # Parse with structural, tabular, lingual, and visual processing.
    parser_udf = get_parser_udf(
        structural=True,
        tabular=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_path),
        language="en",
    )
    doc = parser_udf.apply(doc)

    # Materialize every unigram span mention.
    unigram_space = MentionNgrams(n_min=1, n_max=1)
    return list(unigram_space.apply(doc))
def apply(self, session, context):
    """Expand part-number n-grams into explicit part names.

    Each span is expanded with ``expand_part_range`` (uppercased). When
    ``parts_by_doc`` is provided, known parts for the span's document that
    extend an expanded base part of length >= 4 are added too. Parts
    containing whitespace are dropped; a part identical to the original
    span is yielded as the original span, all others as implicit spans.
    """
    for ts in MentionNgrams.apply(self, session, context):
        enumerated_parts = [
            part.upper() for part in expand_part_range(ts.get_span())
        ]
        parts = set(enumerated_parts)
        if self.parts_by_doc:
            # Augment with document-specific known parts that start with
            # one of the expanded base parts.
            possible_parts = self.parts_by_doc[ts.parent.document.name.upper()]
            for base_part in enumerated_parts:
                for part in possible_parts:
                    if part.startswith(base_part) and len(base_part) >= 4:
                        parts.add(part)
        for i, part in enumerate(parts):
            if " " in part:
                continue  # it won't pass the part_matcher
            if part == ts.get_span():
                yield ts
            else:
                yield TemporaryImplicitSpan(
                    sentence=ts.sentence,
                    char_start=ts.char_start,
                    char_end=ts.char_end,
                    expander_key=u"part_expander",
                    position=i,
                    text=part,
                    words=[part],
                    lemmas=[part],
                    # Linguistic attributes are copied from the span's
                    # first token.
                    pos_tags=[ts.get_attrib_tokens("pos_tags")[0]],
                    ner_tags=[ts.get_attrib_tokens("ner_tags")[0]],
                    dep_parents=[ts.get_attrib_tokens("dep_parents")[0]],
                    dep_labels=[ts.get_attrib_tokens("dep_labels")[0]],
                    # NOTE(review): `left` uses max() while `right` also
                    # uses max(), and `bottom` uses min() while `top` uses
                    # min() — this looks inverted relative to a usual
                    # bounding box (min-left/max-right, min-top/max-bottom);
                    # confirm intended aggregation before relying on these.
                    page=[min(ts.get_attrib_tokens("page"))]
                    if ts.sentence.is_visual()
                    else [None],
                    top=[min(ts.get_attrib_tokens("top"))]
                    if ts.sentence.is_visual()
                    else [None],
                    left=[max(ts.get_attrib_tokens("left"))]
                    if ts.sentence.is_visual()
                    else [None],
                    bottom=[min(ts.get_attrib_tokens("bottom"))]
                    if ts.sentence.is_visual()
                    else [None],
                    right=[max(ts.get_attrib_tokens("right"))]
                    if ts.sentence.is_visual()
                    else [None],
                    meta=None,
                )
def test_visualizer(): """Unit test of visualizer using the md document.""" from fonduer.utils.visualizer import Visualizer, get_box # noqa docs_path = "tests/data/html_simple/md.html" pdf_path = "tests/data/pdf_simple/" # Grab the md document doc = parse_doc(docs_path, "md", pdf_path) assert doc.name == "md" organization_ngrams = MentionNgrams(n_max=1) Org = mention_subclass("Org") organization_matcher = OrganizationMatcher() mention_extractor_udf = MentionExtractorUDF([Org], [organization_ngrams], [organization_matcher]) doc = mention_extractor_udf.apply(doc) Organization = candidate_subclass("Organization", [Org]) candidate_extractor_udf = CandidateExtractorUDF([Organization], None, False, False, True) doc = candidate_extractor_udf.apply(doc, split=0) # Take one candidate cand = doc.organizations[0] pdf_path = "tests/data/pdf_simple" vis = Visualizer(pdf_path) # Test bounding boxes boxes = [get_box(mention.context) for mention in cand.get_mentions()] for box in boxes: assert box.top <= box.bottom assert box.left <= box.right assert boxes == [ mention.context.get_bbox() for mention in cand.get_mentions() ] # Test visualizer vis.display_candidates([cand])
def apply(self, doc):
    """Yield n-grams, collapsing spaced numeric spans (e.g. "± 2 . 3")."""
    for span in MentionNgrams.apply(self, doc):
        match = re.match(r"^(±)?\s*(\d+)\s*(\.)?\s*(\d*)$", span.get_span())
        if not match:
            # Non-numeric spans are forwarded untouched.
            yield span
            continue
        # Handle case that random spaces are inserted (e.g. "± 2 . 3"):
        # join the sign, integer part, decimal point, and fraction into a
        # single space-free token.
        normalized = "".join(g for g in match.groups() if g)
        yield TemporaryImplicitSpanMention(
            sentence=span.sentence,
            char_start=span.char_start,
            char_end=span.char_end,
            expander_key="opamp_exp",
            position=0,
            text=normalized,
            words=[normalized],
            lemmas=[normalized],
            pos_tags=[span.get_attrib_tokens("pos_tags")[-1]],
            ner_tags=[span.get_attrib_tokens("ner_tags")[-1]],
            dep_parents=[span.get_attrib_tokens("dep_parents")[-1]],
            dep_labels=[span.get_attrib_tokens("dep_labels")[-1]],
            page=[span.get_attrib_tokens("page")[-1]]
            if span.sentence.is_visual()
            else [None],
            top=[span.get_attrib_tokens("top")[-1]]
            if span.sentence.is_visual()
            else [None],
            left=[span.get_attrib_tokens("left")[-1]]
            if span.sentence.is_visual()
            else [None],
            bottom=[span.get_attrib_tokens("bottom")[-1]]
            if span.sentence.is_visual()
            else [None],
            right=[span.get_attrib_tokens("right")[-1]]
            if span.sentence.is_visual()
            else [None],
            meta=None,
        )
def apply(self, doc: Document) -> Iterator[TemporaryContext]:
    """Generate mentions by running each mention space over the document.

    :param doc: The ``Document`` to parse.
    :raises TypeError: If the input doc is not of type ``Document``.
    """
    if not isinstance(doc, Document):
        raise TypeError(
            "Input Contexts to MentionNgrams.apply() must be of type Document"
        )
    # Delegate to each underlying mention space in turn.
    yield from MentionNgrams.apply(self, doc)
    yield from MentionDocuments.apply(self, doc)
    yield from MentionFigures.apply(self, doc)
def main(
    conn_string,
    stg_temp_min=False,
    stg_temp_max=False,
    polarity=False,
    ce_v_max=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Run the transistor extraction pipeline end to end.

    Stages: parse documents, extract mentions, extract candidates,
    featurize, label with LFs, train and evaluate an Emmental model.

    :param conn_string: database connection string passed to Meta.init.
    :param stg_temp_min: enable the stg_temp_min relation.
    :param stg_temp_max: enable the stg_temp_max relation.
    :param polarity: enable the polarity relation.
    :param ce_v_max: enable the ce_v_max relation.
    :param max_docs: cap on number of documents to parse.
    :param parse: forwarded to parse_dataset as ``first_time`` — forces
        re-parsing of the corpus.
    :param first_time: run extraction/featurization/supervision from
        scratch; otherwise cached pickle files are loaded.
    :param re_label: re-apply labeling functions to the training split.
    :param parallel: parallelism used for parsing and extraction.
    :param log_dir: log directory name (defaults to "logs").
    :param verbose: log at INFO level instead of WARNING.
    """
    if not log_dir:
        log_dir = "logs"
    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING
    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)
    # rel_list drives every later per-relation loop; order matters because
    # candidate/feature/marginal lists are index-aligned with it.
    rel_list = []
    if stg_temp_min:
        rel_list.append("stg_temp_min")
    if stg_temp_max:
        rel_list.append("stg_temp_max")
    if polarity:
        rel_list.append("polarity")
    if ce_v_max:
        rel_list.append("ce_v_max")
    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info(f"Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=parse, parallel=parallel, max_docs=max_docs
    )
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")
    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []
    # Only do those that are enabled
    Part = mention_subclass("Part")
    part_matcher = get_matcher("part")
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    mentions.append(Part)
    ngrams.append(part_ngrams)
    matchers.append(part_matcher)
    if stg_temp_min:
        StgTempMin = mention_subclass("StgTempMin")
        stg_temp_min_matcher = get_matcher("stg_temp_min")
        stg_temp_min_ngrams = MentionNgramsTemp(n_max=2)
        mentions.append(StgTempMin)
        ngrams.append(stg_temp_min_ngrams)
        matchers.append(stg_temp_min_matcher)
    if stg_temp_max:
        StgTempMax = mention_subclass("StgTempMax")
        stg_temp_max_matcher = get_matcher("stg_temp_max")
        stg_temp_max_ngrams = MentionNgramsTemp(n_max=2)
        mentions.append(StgTempMax)
        ngrams.append(stg_temp_max_ngrams)
        matchers.append(stg_temp_max_matcher)
    if polarity:
        Polarity = mention_subclass("Polarity")
        polarity_matcher = get_matcher("polarity")
        polarity_ngrams = MentionNgrams(n_max=1)
        mentions.append(Polarity)
        ngrams.append(polarity_ngrams)
        matchers.append(polarity_matcher)
    if ce_v_max:
        CeVMax = mention_subclass("CeVMax")
        ce_v_max_matcher = get_matcher("ce_v_max")
        ce_v_max_ngrams = MentionNgramsVolt(n_max=1)
        mentions.append(CeVMax)
        ngrams.append(ce_v_max_ngrams)
        matchers.append(ce_v_max_matcher)
    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)
    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)
    logger.info(f"Total Mentions: {session.query(Mention).count()}")
    logger.info(f"Total Part: {session.query(Part).count()}")
    if stg_temp_min:
        logger.info(f"Total StgTempMin: {session.query(StgTempMin).count()}")
    if stg_temp_max:
        logger.info(f"Total StgTempMax: {session.query(StgTempMax).count()}")
    if polarity:
        logger.info(f"Total Polarity: {session.query(Polarity).count()}")
    if ce_v_max:
        logger.info(f"Total CeVMax: {session.query(CeVMax).count()}")

    # Candidate Extraction
    cands = []
    throttlers = []
    if stg_temp_min:
        PartStgTempMin = candidate_subclass("PartStgTempMin", [Part, StgTempMin])
        stg_temp_min_throttler = stg_temp_filter
        cands.append(PartStgTempMin)
        throttlers.append(stg_temp_min_throttler)
    if stg_temp_max:
        PartStgTempMax = candidate_subclass("PartStgTempMax", [Part, StgTempMax])
        stg_temp_max_throttler = stg_temp_filter
        cands.append(PartStgTempMax)
        throttlers.append(stg_temp_max_throttler)
    if polarity:
        PartPolarity = candidate_subclass("PartPolarity", [Part, Polarity])
        polarity_throttler = polarity_filter
        cands.append(PartPolarity)
        throttlers.append(polarity_throttler)
    if ce_v_max:
        PartCeVMax = candidate_subclass("PartCeVMax", [Part, CeVMax])
        ce_v_max_throttler = ce_v_max_filter
        cands.append(PartCeVMax)
        throttlers.append(ce_v_max_throttler)
    candidate_extractor = CandidateExtractor(session, cands, throttlers=throttlers)
    if first_time:
        # splits 0/1/2 correspond to train/dev/test.
        for i, docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(docs, split=i, parallelism=parallel)
            num_cands = session.query(Candidate).filter(
                Candidate.split == i).count()
            logger.info(f"Candidates in split={i}: {num_cands}")
    # These must be sorted for deterministic behavior.
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    end = timer()
    logger.warning(
        f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")
    logger.info(f"Total train candidate: {sum(len(_) for _ in train_cands)}")
    logger.info(f"Total dev candidate: {sum(len(_) for _ in dev_cands)}")
    logger.info(f"Total test candidate: {sum(len(_) for _ in test_cands)}")
    pickle_file = os.path.join(dirname, "data/parts_by_doc_new.pkl")
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    # Check total recall
    for i, name in enumerate(rel_list):
        logger.info(name)
        result = entity_level_scores(
            candidates_to_entities(dev_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=dev_docs,
        )
        logger.info(f"{name} Total Dev Recall: {result.rec:.3f}")
        result = entity_level_scores(
            candidates_to_entities(test_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=test_docs,
        )
        logger.info(f"{name} Total Test Recall: {result.rec:.3f}")

    # Featurization
    start = timer()
    # Rebuild cands as the list of candidate classes, aligned with rel_list.
    cands = []
    if stg_temp_min:
        cands.append(PartStgTempMin)
    if stg_temp_max:
        cands.append(PartStgTempMax)
    if polarity:
        cands.append(PartPolarity)
    if ce_v_max:
        cands.append(PartCeVMax)
    # Using parallelism = 1 for deterministic behavior.
    featurizer = Featurizer(session, cands, parallelism=1)
    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True)
        featurizer.apply(split=1)
        featurizer.apply(split=2)
        logger.info("Done")
    logger.info("Getting feature matrices...")
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")
        # Cache matrices keyed by relation name for later runs.
        F_train_dict = {}
        F_dev_dict = {}
        F_test_dict = {}
        for idx, relation in enumerate(rel_list):
            F_train_dict[relation] = F_train[idx]
            F_dev_dict[relation] = F_dev[idx]
            F_test_dict[relation] = F_test[idx]
        pickle.dump(F_train_dict, open(os.path.join(dirname, "F_train_dict.pkl"), "wb"))
        pickle.dump(F_dev_dict, open(os.path.join(dirname, "F_dev_dict.pkl"), "wb"))
        pickle.dump(F_test_dict, open(os.path.join(dirname, "F_test_dict.pkl"), "wb"))
    else:
        # Load cached feature matrices and re-order them to match rel_list.
        F_train_dict = pickle.load(
            open(os.path.join(dirname, "F_train_dict.pkl"), "rb"))
        F_dev_dict = pickle.load(
            open(os.path.join(dirname, "F_dev_dict.pkl"), "rb"))
        F_test_dict = pickle.load(
            open(os.path.join(dirname, "F_test_dict.pkl"), "rb"))
        F_train = []
        F_dev = []
        F_test = []
        for relation in rel_list:
            F_train.append(F_train_dict[relation])
            F_dev.append(F_dev_dict[relation])
            F_test.append(F_test_dict[relation])
        logger.info("Done.")
    for i, cand in enumerate(cands):
        logger.info(f"{cand} Train shape: {F_train[i].shape}")
        logger.info(f"{cand} Test shape: {F_test[i].shape}")
        logger.info(f"{cand} Dev shape: {F_dev[i].shape}")
    logger.info("Labeling training data...")

    # Labeling
    start = timer()
    lfs = []
    if stg_temp_min:
        lfs.append(stg_temp_min_lfs)
    if stg_temp_max:
        lfs.append(stg_temp_max_lfs)
    if polarity:
        lfs.append(polarity_lfs)
    if ce_v_max:
        lfs.append(ce_v_max_lfs)
    # Using parallelism = 1 for deterministic behavior.
    labeler = Labeler(session, cands, parallelism=1)
    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True)
        logger.info("Done...")
        # Uncomment if debugging LFs
        # load_transistor_labels(session, cands, ["ce_v_max"])
        # labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        # labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)
    elif re_label:
        logger.info("Updating LFs...")
        labeler.update(split=0, lfs=lfs)
        logger.info("Done...")
        # Uncomment if debugging LFs
        # labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        # labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)
    logger.info("Getting label matrices...")
    L_train = labeler.get_label_matrices(train_cands)
    # Uncomment if debugging LFs
    # L_dev = labeler.get_label_matrices(dev_cands)
    # L_dev_gold = labeler.get_gold_labels(dev_cands, annotator="gold")
    #
    # L_test = labeler.get_label_matrices(test_cands)
    # L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")
    logger.info("Done.")
    if first_time:
        # Fit the generative model per relation and cache the marginals.
        marginals_dict = {}
        for idx, relation in enumerate(rel_list):
            marginals_dict[relation] = generative_model(L_train[idx])
        pickle.dump(marginals_dict,
                    open(os.path.join(dirname, "marginals_dict.pkl"), "wb"))
    else:
        marginals_dict = pickle.load(
            open(os.path.join(dirname, "marginals_dict.pkl"), "rb"))
    marginals = []
    for relation in rel_list:
        marginals.append(marginals_dict[relation])
    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    start = timer()
    word_counter = collect_word_counter(train_cands)

    # Training config
    config = {
        "meta_config": {
            "verbose": True,
            "seed": 17
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.0
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    "model/all/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }
    emmental.init(log_dir=Meta.log_path, config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens marking each mention's boundaries.
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]
    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)
    train_idxs = []
    train_dataloader = []
    for idx, relation in enumerate(rel_list):
        # Train only on candidates whose marginals are not uniform
        # (max - min > 1e-6).
        diffs = marginals[idx].max(axis=1) - marginals[idx].min(axis=1)
        train_idxs.append(np.where(diffs > 1e-6)[0])
        train_dataloader.append(
            EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(
                    relation,
                    train_cands[idx],
                    F_train[idx],
                    emb_layer.word2id,
                    marginals[idx],
                    train_idxs[idx],
                ),
                split="train",
                batch_size=100,
                shuffle=True,
            ))
    num_feature_keys = len(featurizer.get_keys())
    model = EmmentalModel(name=f"transistor_tasks")
    # List relation names, arities, list of classes
    tasks = create_task(
        rel_list,
        [2] * len(rel_list),
        num_feature_keys,
        [2] * len(rel_list),
        emb_layer,
        model="LogisticRegression",
    )
    for task in tasks:
        model.add_task(task)
    emmental_learner = EmmentalLearner()
    # If given a list of multi, will train on multiple
    emmental_learner.learn(model, train_dataloader)

    # One test dataloader per relation.
    for idx, relation in enumerate(rel_list):
        test_dataloader = EmmentalDataLoader(
            task_to_label_dict={relation: "labels"},
            dataset=FonduerDataset(relation, test_cands[idx], F_test[idx],
                                   emb_layer.word2id, 2),
            split="test",
            batch_size=100,
            shuffle=False,
        )
        test_preds = model.predict(test_dataloader, return_preds=True)
        best_result, best_b = scoring(
            relation,
            test_preds,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

        # Dump CSV files for CE_V_MAX for digi-key analysis
        if relation == "ce_v_max":
            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=100,
                shuffle=False,
            )
            dev_preds = model.predict(dev_dataloader, return_preds=True)
            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            dump_candidates(test_cands[idx], Y_prob, "ce_v_max_test_probs.csv")
            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            dump_candidates(dev_cands[idx], Y_prob, "ce_v_max_dev_probs.csv")

        # Dump CSV files for POLARITY for digi-key analysis
        if relation == "polarity":
            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=100,
                shuffle=False,
            )
            dev_preds = model.predict(dev_dataloader, return_preds=True)
            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            dump_candidates(test_cands[idx], Y_prob, "polarity_test_probs.csv")
            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            dump_candidates(dev_cands[idx], Y_prob, "polarity_dev_probs.csv")
    end = timer()
    logger.warning(f"Classification Time (min): {((end - start) / 60.0):.1f}")
def test_feature_extraction():
    """Test extracting candidates from mentions from documents.

    Exercises the Featurizer with the default feature library, a custom
    feature extractor, and each individual feature group (textual,
    tabular, structural, visual), then checks that the groups sum to the
    default set.
    """
    PARALLEL = 1
    max_docs = 1
    session = Meta.init(CONN_STRING).Session()
    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session,
                           structural=True,
                           lingual=True,
                           visual=True,
                           pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)
    temp_ngrams = MentionNgrams(n_max=1)
    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    mention_extractor = MentionExtractor(session, [Part, Temp],
                                         [part_ngrams, temp_ngrams],
                                         [part_matcher, temp_matcher])
    mention_extractor.apply(docs, parallelism=PARALLEL)
    assert docs[0].name == "112823"
    assert session.query(Part).count() == 58
    assert session.query(Temp).count() == 16
    part = session.query(Part).order_by(Part.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    candidate_extractor = CandidateExtractor(session, [PartTemp])
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    n_cands = session.query(PartTemp).count()

    # Featurization based on default feature library
    featurizer = Featurizer(session, [PartTemp])

    # Test that featurization default feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_default_feats = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    # Example feature extractor: one feature per candidate keyed on its id.
    def feat_ext(candidates):
        candidates = candidates if isinstance(candidates,
                                              list) else [candidates]
        for candidate in candidates:
            yield candidate.id, f"cand_id_{candidate.id}", 1

    # Featurization with one extra feature extractor
    feature_extractors = FeatureExtractor(customize_feature_funcs=[feat_ext])
    featurizer = Featurizer(session, [PartTemp],
                            feature_extractors=feature_extractors)

    # Test that featurization default feature library with one extra feature extractor
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_default_w_customized_features = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    # Featurization with only textual feature
    feature_extractors = FeatureExtractor(features=["textual"])
    featurizer = Featurizer(session, [PartTemp],
                            feature_extractors=feature_extractors)

    # Test that featurization textual feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_textual_features = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    # Featurization with only tabular feature
    feature_extractors = FeatureExtractor(features=["tabular"])
    featurizer = Featurizer(session, [PartTemp],
                            feature_extractors=feature_extractors)

    # Test that featurization tabular feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_tabular_features = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    # Featurization with only structural feature
    feature_extractors = FeatureExtractor(features=["structural"])
    featurizer = Featurizer(session, [PartTemp],
                            feature_extractors=feature_extractors)

    # Test that featurization structural feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_structural_features = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    # Featurization with only visual feature
    feature_extractors = FeatureExtractor(features=["visual"])
    featurizer = Featurizer(session, [PartTemp],
                            feature_extractors=feature_extractors)

    # Test that featurization visual feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_visual_features = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    # The four feature groups must partition the default library, and the
    # custom extractor must add exactly one feature key per candidate.
    assert (n_default_feats == n_textual_features + n_tabular_features +
            n_structural_features + n_visual_features)
    assert n_default_w_customized_features == n_default_feats + n_cands
print("Sentences:", session.query(Sentence).count()) print("==============================") # Getting all documents parsed by Snorkel print("Getting documents and sentences...") docs = session.query(Document).all() #sents = session.query(Sentence).all() from fonduer.candidates import CandidateExtractor, MentionExtractor, MentionNgrams from fonduer.candidates.models import mention_subclass, candidate_subclass from fonduer.candidates.matchers import RegexMatchSpan, Union, LambdaFunctionMatcher from dataset_utils import price_match # Defining ngrams for candidates extraction_name = 'price' ngrams = MentionNgrams(n_max=5) # Define matchers matchers = LambdaFunctionMatcher(func=price_match) # Getting candidates PriceMention = mention_subclass("PriceMention") mention_extractor = MentionExtractor( session, [PriceMention], [ngrams], [matchers] ) mention_extractor.clear_all() mention_extractor.apply(docs, parallelism=parallelism) candidate_class = candidate_subclass("Price", [PriceMention]) candidate_extractor = CandidateExtractor(session, [candidate_class]) # Applying candidate extractors
print(f'Number of documents: {session.query(Document).count()}') #print(f'Number of sentences: {session.query(Sentence).count()}') print("==============================") # Getting all documents parsed by Snorkel print("Getting documents and sentences...") docs = session.query(Document).all() #sents = session.query(Sentence).all() from fonduer.candidates import CandidateExtractor, MentionExtractor, MentionNgrams from fonduer.candidates.models import mention_subclass, candidate_subclass from fonduer.candidates.matchers import RegexMatchSpan, Union # Defining ngrams for candidates extraction_name = "age" age_ngrams = MentionNgrams(n_max=3) # Define matchers m = RegexMatchSpan(rgx=r'.*(I|He|She) (is|am) ^([0-9]{2})*') p = RegexMatchSpan(rgx=r'.*(age|is|@|was) ^([0-9]{2})*') q = RegexMatchSpan(rgx=r'.*(age:) ^([0-9]{2})*') r = RegexMatchSpan( rgx=r'.*^([0-9]{2}) (yrs|years|year|yr|old|year-old|yr-old|Years|Year|Yr)*' ) s = RegexMatchSpan(rgx=r'(^|\W)age\W{0,4}[1-9]\d(\W|$)') # Union matchers and create candidate extractor age_matchers = Union(m, p, r, q, s) # Getting candidates AgeMention = mention_subclass("AgeMention")
def main(
    conn_string,
    stg_temp_min=False,
    stg_temp_max=False,
    polarity=False,
    ce_v_max=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    gpu=None,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Run the end-to-end transistor extraction pipeline.

    Parses the corpus, extracts mentions and candidates for each enabled
    relation, featurizes them, applies labeling functions, trains a
    generative + discriminative model per relation, scores on the test
    split, and dumps probability CSVs for the ce_v_max and polarity
    relations.

    :param conn_string: database connection string for ``Meta.init``.
    :param stg_temp_min: enable the storage-temperature-min relation.
    :param stg_temp_max: enable the storage-temperature-max relation.
    :param polarity: enable the polarity relation.
    :param ce_v_max: enable the collector-emitter max-voltage relation.
    :param max_docs: cap on the number of documents to parse.
    :param parse: re-parse the corpus from scratch when True.
    :param first_time: run extraction/featurization/labeling from scratch
        (otherwise cached pickles are loaded).
    :param re_label: re-apply labeling functions on the train split.
    :param gpu: CUDA device id string; sets ``CUDA_VISIBLE_DEVICES``.
    :param parallel: parallelism for parsing/extraction/featurization.
    :param log_dir: log directory (defaults to "logs").
    :param verbose: log at INFO instead of WARNING.
    """
    # Setup initial configuration
    if gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    # rel_list order determines the index of each relation in all the
    # parallel lists (cands, F_*, L_train, ...) built below.
    rel_list = []
    if stg_temp_min:
        rel_list.append("stg_temp_min")

    if stg_temp_max:
        rel_list.append("stg_temp_max")

    if polarity:
        rel_list.append("polarity")

    if ce_v_max:
        rel_list.append("ce_v_max")

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info(f"Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=parse, parallel=parallel, max_docs=max_docs)
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled. The Part mention is always extracted
    # because every relation pairs an attribute with a part.
    Part = mention_subclass("Part")
    part_matcher = get_matcher("part")
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)

    mentions.append(Part)
    ngrams.append(part_ngrams)
    matchers.append(part_matcher)

    if stg_temp_min:
        StgTempMin = mention_subclass("StgTempMin")
        stg_temp_min_matcher = get_matcher("stg_temp_min")
        stg_temp_min_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMin)
        ngrams.append(stg_temp_min_ngrams)
        matchers.append(stg_temp_min_matcher)

    if stg_temp_max:
        StgTempMax = mention_subclass("StgTempMax")
        stg_temp_max_matcher = get_matcher("stg_temp_max")
        stg_temp_max_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMax)
        ngrams.append(stg_temp_max_ngrams)
        matchers.append(stg_temp_max_matcher)

    if polarity:
        Polarity = mention_subclass("Polarity")
        polarity_matcher = get_matcher("polarity")
        polarity_ngrams = MentionNgrams(n_max=1)

        mentions.append(Polarity)
        ngrams.append(polarity_ngrams)
        matchers.append(polarity_matcher)

    if ce_v_max:
        CeVMax = mention_subclass("CeVMax")
        ce_v_max_matcher = get_matcher("ce_v_max")
        ce_v_max_ngrams = MentionNgramsVolt(n_max=1)

        mentions.append(CeVMax)
        ngrams.append(ce_v_max_ngrams)
        matchers.append(ce_v_max_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")
    logger.info(f"Total Part: {session.query(Part).count()}")
    if stg_temp_min:
        logger.info(f"Total StgTempMin: {session.query(StgTempMin).count()}")
    if stg_temp_max:
        logger.info(f"Total StgTempMax: {session.query(StgTempMax).count()}")
    if polarity:
        logger.info(f"Total Polarity: {session.query(Polarity).count()}")
    if ce_v_max:
        logger.info(f"Total CeVMax: {session.query(CeVMax).count()}")

    # Candidate Extraction: one candidate class + throttler per relation,
    # in rel_list order.
    cands = []
    throttlers = []
    if stg_temp_min:
        PartStgTempMin = candidate_subclass("PartStgTempMin", [Part, StgTempMin])
        stg_temp_min_throttler = stg_temp_filter

        cands.append(PartStgTempMin)
        throttlers.append(stg_temp_min_throttler)

    if stg_temp_max:
        PartStgTempMax = candidate_subclass("PartStgTempMax", [Part, StgTempMax])
        stg_temp_max_throttler = stg_temp_filter

        cands.append(PartStgTempMax)
        throttlers.append(stg_temp_max_throttler)

    if polarity:
        PartPolarity = candidate_subclass("PartPolarity", [Part, Polarity])
        polarity_throttler = polarity_filter

        cands.append(PartPolarity)
        throttlers.append(polarity_throttler)

    if ce_v_max:
        PartCeVMax = candidate_subclass("PartCeVMax", [Part, CeVMax])
        ce_v_max_throttler = ce_v_max_filter

        cands.append(PartCeVMax)
        throttlers.append(ce_v_max_throttler)

    candidate_extractor = CandidateExtractor(session, cands, throttlers=throttlers)

    if first_time:
        # Splits: 0 = train, 1 = dev, 2 = test.
        for i, docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(docs, split=i, parallelism=parallel)
            num_cands = session.query(Candidate).filter(
                Candidate.split == i).count()
            logger.info(f"Candidates in split={i}: {num_cands}")

    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)
    end = timer()
    logger.warning(
        f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"Total train candidate: {sum(len(_) for _ in train_cands)}")
    logger.info(f"Total dev candidate: {sum(len(_) for _ in dev_cands)}")
    logger.info(f"Total test candidate: {sum(len(_) for _ in test_cands)}")

    # Pre-computed mapping from document name to part numbers, used to
    # normalize candidates to entities during scoring.
    pickle_file = os.path.join(dirname, "data/parts_by_doc_new.pkl")
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    # Check total recall
    for i, name in enumerate(rel_list):
        logger.info(name)
        result = entity_level_scores(
            candidates_to_entities(dev_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=dev_docs,
        )
        logger.info(f"{name} Total Dev Recall: {result.rec:.3f}")
        result = entity_level_scores(
            candidates_to_entities(test_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=test_docs,
        )
        logger.info(f"{name} Total Test Recall: {result.rec:.3f}")

    # Featurization
    start = timer()
    # Rebuild cands so its order matches rel_list (Part-only entries gone).
    cands = []
    if stg_temp_min:
        cands.append(PartStgTempMin)

    if stg_temp_max:
        cands.append(PartStgTempMax)

    if polarity:
        cands.append(PartPolarity)

    if ce_v_max:
        cands.append(PartCeVMax)

    featurizer = Featurizer(session, cands)
    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True, parallelism=parallel)
        featurizer.apply(split=1, parallelism=parallel)
        featurizer.apply(split=2, parallelism=parallel)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    # Feature matrices are cached as pickles so re-runs skip featurization.
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")
        pickle.dump(F_train, open(os.path.join(dirname, "F_train.pkl"), "wb"))
        pickle.dump(F_dev, open(os.path.join(dirname, "F_dev.pkl"), "wb"))
        pickle.dump(F_test, open(os.path.join(dirname, "F_test.pkl"), "wb"))
    else:
        F_train = pickle.load(open(os.path.join(dirname, "F_train.pkl"), "rb"))
        F_dev = pickle.load(open(os.path.join(dirname, "F_dev.pkl"), "rb"))
        F_test = pickle.load(open(os.path.join(dirname, "F_test.pkl"), "rb"))
    logger.info("Done.")

    for i, cand in enumerate(cands):
        logger.info(f"{cand} Train shape: {F_train[i].shape}")
        logger.info(f"{cand} Test shape: {F_test[i].shape}")
        logger.info(f"{cand} Dev shape: {F_dev[i].shape}")

    logger.info("Labeling training data...")

    # Labeling: one list of labeling functions per enabled relation, in
    # rel_list order.
    start = timer()
    lfs = []
    if stg_temp_min:
        lfs.append(stg_temp_min_lfs)

    if stg_temp_max:
        lfs.append(stg_temp_max_lfs)

    if polarity:
        lfs.append(polarity_lfs)

    if ce_v_max:
        lfs.append(ce_v_max_lfs)

    labeler = Labeler(session, cands)

    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
        logger.info("Done...")

        # Uncomment if debugging LFs
        # load_transistor_labels(session, cands, ["ce_v_max"])
        # labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        # labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)
    elif re_label:
        logger.info("Updating LFs...")
        labeler.update(split=0, lfs=lfs, parallelism=parallel)
        logger.info("Done...")

        # Uncomment if debugging LFs
        # labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        # labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    logger.info("Getting label matrices...")
    L_train = labeler.get_label_matrices(train_cands)

    # Uncomment if debugging LFs
    # L_dev = labeler.get_label_matrices(dev_cands)
    # L_dev_gold = labeler.get_gold_labels(dev_cands, annotator="gold")
    #
    # L_test = labeler.get_label_matrices(test_cands)
    # L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")

    logger.info("Done.")
    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    # Train a generative model on the LF outputs, then a discriminative
    # model on its marginals, per enabled relation; score on test.
    start = timer()

    if stg_temp_min:
        relation = "stg_temp_min"
        idx = rel_list.index(relation)
        marginals_stg_temp_min = generative_model(L_train[idx])
        disc_model_stg_temp_min = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals_stg_temp_min,
            n_epochs=100,
            gpu=gpu,
        )
        best_result, best_b = scoring(
            relation,
            disc_model_stg_temp_min,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

    if stg_temp_max:
        relation = "stg_temp_max"
        idx = rel_list.index(relation)
        marginals_stg_temp_max = generative_model(L_train[idx])
        disc_model_stg_temp_max = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals_stg_temp_max,
            n_epochs=100,
            gpu=gpu,
        )
        best_result, best_b = scoring(
            relation,
            disc_model_stg_temp_max,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

    if polarity:
        relation = "polarity"
        idx = rel_list.index(relation)
        marginals_polarity = generative_model(L_train[idx])
        disc_model_polarity = discriminative_model(train_cands[idx],
                                                   F_train[idx],
                                                   marginals_polarity,
                                                   n_epochs=100,
                                                   gpu=gpu)
        best_result, best_b = scoring(
            relation,
            disc_model_polarity,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

    if ce_v_max:
        relation = "ce_v_max"
        idx = rel_list.index(relation)

        # Can be uncommented for use in debugging labeling functions
        # logger.info("Updating labeling function summary...")
        # keys = labeler.get_keys()
        # logger.info("Summary for train set labeling functions:")
        # df = analysis.lf_summary(L_train[idx], lf_names=keys)
        # logger.info(f"\n{df.to_string()}")
        #
        # logger.info("Summary for dev set labeling functions:")
        # df = analysis.lf_summary(
        #     L_dev[idx],
        #     lf_names=keys,
        #     Y=L_dev_gold[idx].todense().reshape(-1).tolist()[0],
        # )
        # logger.info(f"\n{df.to_string()}")
        #
        # logger.info("Summary for test set labeling functions:")
        # df = analysis.lf_summary(
        #     L_test[idx],
        #     lf_names=keys,
        #     Y=L_test_gold[idx].todense().reshape(-1).tolist()[0],
        # )
        # logger.info(f"\n{df.to_string()}")

        marginals_ce_v_max = generative_model(L_train[idx])
        disc_model_ce_v_max = discriminative_model(train_cands[idx],
                                                   F_train[idx],
                                                   marginals_ce_v_max,
                                                   n_epochs=100,
                                                   gpu=gpu)

        # Can be uncommented to view score on development set
        # best_result, best_b = scoring(
        #     relation,
        #     disc_model_ce_v_max,
        #     dev_cands[idx],
        #     dev_docs,
        #     F_dev[idx],
        #     parts_by_doc,
        #     num=100,
        # )

        best_result, best_b = scoring(
            relation,
            disc_model_ce_v_max,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

    end = timer()
    logger.warning(f"Classification Time (min): {((end - start) / 60.0):.1f}")

    # Dump CSV files for CE_V_MAX for digi-key analysis
    if ce_v_max:
        relation = "ce_v_max"
        idx = rel_list.index(relation)
        Y_prob = disc_model_ce_v_max.marginals((test_cands[idx], F_test[idx]))
        dump_candidates(test_cands[idx], Y_prob, "ce_v_max_test_probs.csv")
        Y_prob = disc_model_ce_v_max.marginals((dev_cands[idx], F_dev[idx]))
        dump_candidates(dev_cands[idx], Y_prob, "ce_v_max_dev_probs.csv")

    # Dump CSV files for POLARITY for digi-key analysis
    if polarity:
        relation = "polarity"
        idx = rel_list.index(relation)
        Y_prob = disc_model_polarity.marginals((test_cands[idx], F_test[idx]))
        dump_candidates(test_cands[idx], Y_prob, "polarity_test_probs.csv")
        Y_prob = disc_model_polarity.marginals((dev_cands[idx], F_dev[idx]))
        dump_candidates(dev_cands[idx], Y_prob, "polarity_dev_probs.csv")
    # NOTE(review): this `]` closes the `stations` list (a list of station
    # name lists) whose opening lies above this chunk — confirm in full file.
]
# Map every station alias back to its alias group, and flatten all aliases
# into one list for the regex alternation below.
stations_mapping_dict = {
    k: station_list for station_list in stations for k in station_list
}
stations_list = [s for station_list in stations for s in station_list]
station_rgx = '|'.join(stations_list)

# 1.) Mention classes
Station = mention_subclass("Station")
Price = mention_subclass("Price")

# 2.) Mention spaces
# NOTE(review): "\." in split_tokens is an invalid string escape (equals
# "\\." only by CPython leniency) — presumably meant as a regex-escaped dot;
# confirm how MentionNgrams consumes split_tokens.
station_ngrams = MentionNgrams(
    n_max=4, split_tokens=[" ", "_", "\.", "%"])  # StationMentionSpace(n_max=4)
# price_ngrams = MentionNgrams(n_max=1)

# 3.) Matcher functions
station_matcher = RegexMatchFull(
    rgx=station_rgx,
    ignore_case=True,
    # search=True,
    # full_match=False,
    # longest_match_only=False,
)  # DictionaryMatch(d=stations_list)
# Prices: digits with a mandatory decimal part, longest span wins.
price_matcher = RegexMatchSpan(rgx=r"\d{1,4}(\.\d{1,5})", longest_match_only=True)

# 4.) Candidate classes
def main(
    conn_string,
    gain=False,
    current=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    parallel=8,
    log_dir="logs",
    verbose=False,
):
    """Run the opamp extraction pipeline (gain / supply-current relations).

    Parses the corpus, extracts mentions and candidates, featurizes them,
    builds training marginals directly from human dev-set annotations,
    trains a multi-task Emmental model, scores on the test split, and dumps
    probability CSVs for analysis.

    :param conn_string: database connection string for ``Meta.init``.
    :param gain: enable the gain relation.
    :param current: enable the supply-current relation.
    :param max_docs: cap on the number of documents to parse.
    :param parse: re-parse the corpus from scratch when True.
    :param first_time: run extraction/featurization/marginal construction
        from scratch (otherwise cached pickles are loaded).
    :param re_label: unused in the active code path (LF application is
        commented out below).
    :param parallel: parallelism for parsing/extraction.
    :param log_dir: log directory (defaults to "logs").
    :param verbose: log at INFO instead of WARNING.
    """
    # Setup initial configuration
    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    # rel_list order determines each relation's index in all parallel
    # lists (cand_classes, F_*, marginals, ...) built below.
    rel_list = []
    if gain:
        rel_list.append("gain")

    if current:
        rel_list.append("current")

    logger.info(f"=" * 30)
    logger.info(f"Running with parallel: {parallel}, max_docs: {max_docs}")

    session = Meta.init(conn_string).Session()

    # Parsing
    start = timer()
    logger.info(f"Starting parsing...")
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=parse, parallel=parallel, max_docs=max_docs)
    logger.debug(f"Done")
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of Documents: {len(docs)}")
    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    if gain:
        Gain = mention_subclass("Gain")
        gain_matcher = get_gain_matcher()
        gain_ngrams = MentionNgrams(n_max=2)
        mentions.append(Gain)
        ngrams.append(gain_ngrams)
        matchers.append(gain_matcher)

    if current:
        Current = mention_subclass("SupplyCurrent")
        current_matcher = get_supply_current_matcher()
        current_ngrams = MentionNgramsCurrent(n_max=3)
        mentions.append(Current)
        ngrams.append(current_ngrams)
        matchers.append(current_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")

    if gain:
        logger.info(f"Total Gain: {session.query(Gain).count()}")

    if current:
        logger.info(f"Total Current: {session.query(Current).count()}")

    cand_classes = []
    if gain:
        GainCand = candidate_subclass("GainCand", [Gain])
        cand_classes.append(GainCand)
    if current:
        CurrentCand = candidate_subclass("CurrentCand", [Current])
        cand_classes.append(CurrentCand)

    candidate_extractor = CandidateExtractor(session, cand_classes)

    if first_time:
        # Splits: 0 = train, 1 = dev, 2 = test.
        for i, docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(docs, split=i, parallelism=parallel)

    # These must be sorted for deterministic behavior.
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)

    # NOTE(review): these totals assume both relations are enabled
    # (indexes [0] and [1]); with only one enabled this raises IndexError.
    logger.info(
        f"Total train candidate: {len(train_cands[0]) + len(train_cands[1])}")
    logger.info(
        f"Total dev candidate: {len(dev_cands[0]) + len(dev_cands[1])}")
    logger.info(
        f"Total test candidate: {len(test_cands[0]) + len(test_cands[1])}")

    logger.info("Done w/ candidate extraction.")
    end = timer()
    logger.warning(f"CE Time (min): {((end - start) / 60.0):.1f}")

    # First, check total recall
    # result = entity_level_scores(
    #     candidates_to_entities(dev_cands[0], is_gain=True),
    #     corpus=dev_docs,
    #     is_gain=True,
    # )
    # logger.info(f"Gain Total Dev Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")
    # result = entity_level_scores(
    #     candidates_to_entities(test_cands[0], is_gain=True),
    #     corpus=test_docs,
    #     is_gain=True,
    # )
    # logger.info(f"Gain Total Test Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")
    #
    # result = entity_level_scores(
    #     candidates_to_entities(dev_cands[1], is_gain=False),
    #     corpus=dev_docs,
    #     is_gain=False,
    # )
    # logger.info(f"Current Total Dev Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")
    # result = entity_level_scores(
    #     candidates_to_entities(test_cands[1], is_gain=False),
    #     corpus=test_docs,
    #     is_gain=False,
    # )
    # logger.info(f"Current Test Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")

    start = timer()

    # Using parallelism = 1 for deterministic behavior.
    featurizer = Featurizer(session, cand_classes, parallelism=1)
    if first_time:
        logger.info("Starting featurizer...")
        # Set feature space based on dev set, which we use for training rather
        # than the large train set.
        featurizer.apply(split=1, train=True)
        featurizer.apply(split=0)
        featurizer.apply(split=2)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    # Serialize feature matrices on first run
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")
        # Keyed by relation name so a cached file survives changes in
        # which relations are enabled.
        F_train_dict = {}
        F_dev_dict = {}
        F_test_dict = {}
        for idx, relation in enumerate(rel_list):
            F_train_dict[relation] = F_train[idx]
            F_dev_dict[relation] = F_dev[idx]
            F_test_dict[relation] = F_test[idx]

        pickle.dump(F_train_dict,
                    open(os.path.join(dirname, "F_train_dict.pkl"), "wb"))
        pickle.dump(F_dev_dict,
                    open(os.path.join(dirname, "F_dev_dict.pkl"), "wb"))
        pickle.dump(F_test_dict,
                    open(os.path.join(dirname, "F_test_dict.pkl"), "wb"))
    else:
        F_train_dict = pickle.load(
            open(os.path.join(dirname, "F_train_dict.pkl"), "rb"))
        F_dev_dict = pickle.load(
            open(os.path.join(dirname, "F_dev_dict.pkl"), "rb"))
        F_test_dict = pickle.load(
            open(os.path.join(dirname, "F_test_dict.pkl"), "rb"))

        F_train = []
        F_dev = []
        F_test = []
        for relation in rel_list:
            F_train.append(F_train_dict[relation])
            F_dev.append(F_dev_dict[relation])
            F_test.append(F_test_dict[relation])

    logger.info("Done.")

    start = timer()
    logger.info("Labeling training data...")
    # labeler = Labeler(session, cand_classes)
    # lfs = []
    # if gain:
    #     lfs.append(gain_lfs)
    #
    # if current:
    #     lfs.append(current_lfs)
    #
    # if first_time:
    #     logger.info("Applying LFs...")
    #     labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
    # elif re_label:
    #     logger.info("Re-applying LFs...")
    #     labeler.update(split=0, lfs=lfs, parallelism=parallel)
    #
    # logger.info("Done...")
    # logger.info("Getting label matrices...")
    # L_train = labeler.get_label_matrices(train_cands)
    # logger.info("Done...")

    if first_time:
        marginals_dict = {}
        for idx, relation in enumerate(rel_list):
            # Manually create marginals from human annotations: a dev
            # candidate is positive iff any of its entity forms appears in
            # the gold set.
            marginal = []
            dev_gold_entities = get_gold_set(is_gain=(relation == "gain"))
            for c in dev_cands[idx]:
                flag = False
                for entity in cand_to_entity(c, is_gain=(relation == "gain")):
                    if entity in dev_gold_entities:
                        flag = True

                if flag:
                    marginal.append([0.0, 1.0])
                else:
                    marginal.append([1.0, 0.0])

            marginals_dict[relation] = np.array(marginal)

        pickle.dump(marginals_dict,
                    open(os.path.join(dirname, "marginals_dict.pkl"), "wb"))
    else:
        marginals_dict = pickle.load(
            open(os.path.join(dirname, "marginals_dict.pkl"), "rb"))

    marginals = []
    for relation in rel_list:
        marginals.append(marginals_dict[relation])

    end = timer()
    logger.warning(
        f"Weak Supervision Time (min): {((end - start) / 60.0):.1f}")

    start = timer()

    word_counter = collect_word_counter(train_cands)

    # Training config
    config = {
        "meta_config": {
            "verbose": True,
            "seed": 30
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 500,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.005
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    "model/all/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }

    emmental.init(log_dir=Meta.log_path, config=config)

    # Generate word embedding module
    arity = 2

    # Geneate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)

    train_idxs = []
    train_dataloader = []
    for idx, relation in enumerate(rel_list):
        # Keep only candidates whose marginals are not (0.5, 0.5), i.e.
        # those with a definite label.
        diffs = marginals[idx].max(axis=1) - marginals[idx].min(axis=1)
        train_idxs.append(np.where(diffs > 1e-6)[0])

        # only uses dev set as training data, with human annotations
        train_dataloader.append(
            EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(
                    relation,
                    dev_cands[idx],
                    F_dev[idx],
                    emb_layer.word2id,
                    marginals[idx],
                    train_idxs[idx],
                ),
                split="train",
                batch_size=256,
                shuffle=True,
            ))

    num_feature_keys = len(featurizer.get_keys())

    model = EmmentalModel(name=f"opamp_tasks")

    # List relation names, arities, list of classes
    tasks = create_task(
        rel_list,
        [2] * len(rel_list),
        num_feature_keys,
        [2] * len(rel_list),
        emb_layer,
        model="LogisticRegression",
    )

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()

    # If given a list of multi, will train on multiple
    emmental_learner.learn(model, train_dataloader)

    # List of dataloader for each relation
    for idx, relation in enumerate(rel_list):
        test_dataloader = EmmentalDataLoader(
            task_to_label_dict={relation: "labels"},
            dataset=FonduerDataset(relation, test_cands[idx], F_test[idx],
                                   emb_layer.word2id, 2),
            split="test",
            batch_size=256,
            shuffle=False,
        )

        test_preds = model.predict(test_dataloader, return_preds=True)

        best_result, best_b = scoring(
            test_preds,
            test_cands[idx],
            test_docs,
            is_gain=(relation == "gain"),
            num=100,
        )

        # Dump CSV files for analysis
        if relation == "gain":
            train_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, train_cands[idx],
                                       F_train[idx], emb_layer.word2id, 2),
                split="train",
                batch_size=256,
                shuffle=False,
            )

            train_preds = model.predict(train_dataloader, return_preds=True)
            # NOTE(review): TRUE is presumably a module-level index constant
            # for the positive-class column — confirm in the file header.
            Y_prob = np.array(train_preds["probs"][relation])[:, TRUE]
            output_csv(train_cands[idx], Y_prob, is_gain=True)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            output_csv(test_cands[idx], Y_prob, is_gain=True, append=True)
            dump_candidates(test_cands[idx],
                            Y_prob,
                            "gain_test_probs.csv",
                            is_gain=True)

            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=256,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            output_csv(dev_cands[idx], Y_prob, is_gain=True, append=True)
            dump_candidates(dev_cands[idx],
                            Y_prob,
                            "gain_dev_probs.csv",
                            is_gain=True)

        if relation == "current":
            train_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, train_cands[idx],
                                       F_train[idx], emb_layer.word2id, 2),
                split="train",
                batch_size=256,
                shuffle=False,
            )

            train_preds = model.predict(train_dataloader, return_preds=True)
            Y_prob = np.array(train_preds["probs"][relation])[:, TRUE]
            output_csv(train_cands[idx], Y_prob, is_gain=False)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            output_csv(test_cands[idx], Y_prob, is_gain=False, append=True)
            dump_candidates(test_cands[idx],
                            Y_prob,
                            "current_test_probs.csv",
                            is_gain=False)

            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=256,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            output_csv(dev_cands[idx], Y_prob, is_gain=False, append=True)
            dump_candidates(dev_cands[idx],
                            Y_prob,
                            "current_dev_probs.csv",
                            is_gain=False)

    end = timer()
    logger.warning(
        f"Classification AND dump data Time (min): {((end - start) / 60.0):.1f}"
    )
def main(
    conn_string,
    gain=False,
    current=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    gpu=None,
    parallel=8,
    log_dir="logs",
    verbose=False,
):
    """Run the opamp extraction pipeline with LF-based weak supervision.

    Parses the corpus, extracts mentions and candidates for the enabled
    relations (gain / supply current), featurizes, applies labeling
    functions, trains a generative + discriminative model per relation,
    scores on test, and dumps probability CSVs for Opo/Digi-Key analysis.

    :param conn_string: database connection string for ``Meta.init``.
    :param gain: enable the gain relation.
    :param current: enable the supply-current relation.
    :param max_docs: cap on the number of documents to parse.
    :param parse: re-parse the corpus from scratch when True.
    :param first_time: run extraction/featurization/labeling from scratch
        (otherwise cached pickles are loaded).
    :param re_label: re-apply labeling functions on the train split.
    :param gpu: CUDA device id string; sets ``CUDA_VISIBLE_DEVICES``.
    :param parallel: parallelism for parsing/extraction/featurization.
    :param log_dir: log directory (defaults to "logs").
    :param verbose: log at INFO instead of WARNING.
    """
    # Setup initial configuration
    if gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    # rel_list order determines each relation's index in all parallel
    # lists (cand_classes, F_*, L_train, ...) built below.
    rel_list = []
    if gain:
        rel_list.append("gain")

    if current:
        rel_list.append("current")

    logger.info(f"=" * 30)
    logger.info(f"Running with parallel: {parallel}, max_docs: {max_docs}")

    session = Meta.init(conn_string).Session()

    # Parsing
    start = timer()
    logger.info(f"Starting parsing...")
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=parse, parallel=parallel, max_docs=max_docs)
    logger.debug(f"Done")
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of Documents: {len(docs)}")
    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    if gain:
        Gain = mention_subclass("Gain")
        gain_matcher = get_gain_matcher()
        gain_ngrams = MentionNgrams(n_max=2)
        mentions.append(Gain)
        ngrams.append(gain_ngrams)
        matchers.append(gain_matcher)

    if current:
        Current = mention_subclass("SupplyCurrent")
        current_matcher = get_supply_current_matcher()
        current_ngrams = MentionNgramsCurrent(n_max=3)
        mentions.append(Current)
        ngrams.append(current_ngrams)
        matchers.append(current_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")

    if gain:
        logger.info(f"Total Gain: {session.query(Gain).count()}")

    if current:
        logger.info(f"Total Current: {session.query(Current).count()}")

    cand_classes = []
    if gain:
        GainCand = candidate_subclass("GainCand", [Gain])
        cand_classes.append(GainCand)
    if current:
        CurrentCand = candidate_subclass("CurrentCand", [Current])
        cand_classes.append(CurrentCand)

    candidate_extractor = CandidateExtractor(session, cand_classes)

    if first_time:
        # Splits: 0 = train, 1 = dev, 2 = test.
        for i, docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(docs, split=i, parallelism=parallel)

    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)

    # NOTE(review): these totals assume both relations are enabled
    # (indexes [0] and [1]); with only one enabled this raises IndexError.
    logger.info(
        f"Total train candidate: {len(train_cands[0]) + len(train_cands[1])}")
    logger.info(
        f"Total dev candidate: {len(dev_cands[0]) + len(dev_cands[1])}")
    logger.info(
        f"Total test candidate: {len(test_cands[0]) + len(test_cands[1])}")

    logger.info("Done w/ candidate extraction.")
    end = timer()
    logger.warning(f"CE Time (min): {((end - start) / 60.0):.1f}")

    # First, check total recall
    # result = entity_level_scores(dev_cands[0], corpus=dev_docs)
    # logger.info(f"Gain Total Dev Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")
    # result = entity_level_scores(test_cands[0], corpus=test_docs)
    # logger.info(f"Gain Total Test Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")
    #
    # result = entity_level_scores(dev_cands[1], corpus=dev_docs, is_gain=False)
    # logger.info(f"Current Total Dev Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")
    # result = entity_level_scores(test_cands[1], corpus=test_docs, is_gain=False)
    # logger.info(f"Current Test Recall: {result.rec:.3f}")
    # logger.info(f"\n{pformat(result.FN)}")

    start = timer()

    featurizer = Featurizer(session, cand_classes)

    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True, parallelism=parallel)
        featurizer.apply(split=1, parallelism=parallel)
        featurizer.apply(split=2, parallelism=parallel)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    # Serialize feature matrices on first run
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")
        pickle.dump(F_train, open(os.path.join(dirname, "F_train.pkl"), "wb"))
        pickle.dump(F_dev, open(os.path.join(dirname, "F_dev.pkl"), "wb"))
        pickle.dump(F_test, open(os.path.join(dirname, "F_test.pkl"), "wb"))
    else:
        F_train = pickle.load(open(os.path.join(dirname, "F_train.pkl"), "rb"))
        F_dev = pickle.load(open(os.path.join(dirname, "F_dev.pkl"), "rb"))
        F_test = pickle.load(open(os.path.join(dirname, "F_test.pkl"), "rb"))
    logger.info("Done.")

    start = timer()
    logger.info("Labeling training data...")
    labeler = Labeler(session, cand_classes)
    lfs = []
    if gain:
        lfs.append(gain_lfs)

    if current:
        lfs.append(current_lfs)

    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
    elif re_label:
        logger.info("Re-applying LFs...")
        labeler.update(split=0, lfs=lfs, parallelism=parallel)

    logger.info("Done...")
    logger.info("Getting label matrices...")
    L_train = labeler.get_label_matrices(train_cands)
    logger.info("Done...")

    end = timer()
    logger.warning(
        f"Weak Supervision Time (min): {((end - start) / 60.0):.1f}")

    if gain:
        relation = "gain"
        idx = rel_list.index(relation)
        logger.info("Score Gain.")
        # Build dev ground-truth labels: a candidate is positive iff any of
        # its entity forms appears in the gold set.
        dev_gold_entities = get_gold_set(is_gain=True)
        L_dev_gt = []
        for c in dev_cands[idx]:
            # NOTE(review): FALSE/TRUE are presumably module-level label
            # constants (cf. `[:, TRUE]` indexing elsewhere in this repo);
            # if they are not defined, these should be False/True — confirm
            # against the file header.
            flag = FALSE
            for entity in cand_to_entity(c, is_gain=True):
                if entity in dev_gold_entities:
                    flag = TRUE
            L_dev_gt.append(flag)

        marginals = generative_model(L_train[idx])
        disc_models = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals,
            X_dev=(dev_cands[idx], F_dev[idx]),
            Y_dev=L_dev_gt,
            n_epochs=500,
            gpu=gpu,
        )
        best_result, best_b = scoring(disc_models,
                                      test_cands[idx],
                                      test_docs,
                                      F_test[idx],
                                      num=50)
        print_scores(relation, best_result, best_b)

        logger.info("Output CSV files for Opo and Digi-key Analysis.")
        Y_prob = disc_models.marginals((train_cands[idx], F_train[idx]))
        output_csv(train_cands[idx], Y_prob, is_gain=True)

        Y_prob = disc_models.marginals((test_cands[idx], F_test[idx]))
        output_csv(test_cands[idx], Y_prob, is_gain=True, append=True)
        dump_candidates(test_cands[idx],
                        Y_prob,
                        "gain_test_probs.csv",
                        is_gain=True)

        Y_prob = disc_models.marginals((dev_cands[idx], F_dev[idx]))
        output_csv(dev_cands[idx], Y_prob, is_gain=True, append=True)
        dump_candidates(dev_cands[idx],
                        Y_prob,
                        "gain_dev_probs.csv",
                        is_gain=True)

    if current:
        relation = "current"
        idx = rel_list.index(relation)
        logger.info("Score Current.")
        dev_gold_entities = get_gold_set(is_gain=False)
        L_dev_gt = []
        for c in dev_cands[idx]:
            flag = FALSE
            for entity in cand_to_entity(c, is_gain=False):
                if entity in dev_gold_entities:
                    flag = TRUE
            L_dev_gt.append(flag)

        marginals = generative_model(L_train[idx])

        disc_models = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals,
            X_dev=(dev_cands[idx], F_dev[idx]),
            Y_dev=L_dev_gt,
            n_epochs=100,
            gpu=gpu,
        )
        best_result, best_b = scoring(disc_models,
                                      test_cands[idx],
                                      test_docs,
                                      F_test[idx],
                                      is_gain=False,
                                      num=50)
        print_scores(relation, best_result, best_b)

        logger.info("Output CSV files for Opo and Digi-key Analysis.")
        # Dump CSV files for digi-key analysis and Opo comparison
        Y_prob = disc_models.marginals((train_cands[idx], F_train[idx]))
        output_csv(train_cands[idx], Y_prob, is_gain=False)

        Y_prob = disc_models.marginals((test_cands[idx], F_test[idx]))
        output_csv(test_cands[idx], Y_prob, is_gain=False, append=True)
        dump_candidates(test_cands[idx],
                        Y_prob,
                        "current_test_probs.csv",
                        is_gain=False)

        Y_prob = disc_models.marginals((dev_cands[idx], F_dev[idx]))
        output_csv(dev_cands[idx], Y_prob, is_gain=False, append=True)
        dump_candidates(dev_cands[idx],
                        Y_prob,
                        "current_dev_probs.csv",
                        is_gain=False)

    end = timer()
    logger.warning(
        f"Classification AND dump data Time (min): {((end - start) / 60.0):.1f}"
    )
# In[12]: birth_place_in_labeled_row_matcher = LambdaFunctionMatcher( func=is_in_birthplace_table_row) birth_place_in_labeled_row_matcher.longest_match_only = False birth_place_no_commas_matcher = LambdaFunctionMatcher( func=no_commas_in_birth_place) birth_place_left_aligned_matcher = LambdaFunctionMatcher( func=birthplace_left_aligned_to_punctuation) place_of_birth_matcher = Intersect( birth_place_in_labeled_row_matcher, birth_place_no_commas_matcher, birth_place_left_aligned_matcher, ) from fonduer.candidates import MentionNgrams presname_ngrams = MentionNgrams(n_max=4, n_min=2) placeofbirth_ngrams = MentionNgrams(n_max=3) from fonduer.candidates.models import candidate_subclass PresidentnamePlaceofbirth = candidate_subclass("PresidentnamePlaceofbirth", [Presidentname, Placeofbirth]) mention_classes = [Presidentname, Placeofbirth] mention_spaces = [presname_ngrams, placeofbirth_ngrams] matchers = [president_name_matcher, place_of_birth_matcher] candidate_classes = [PresidentnamePlaceofbirth]
def test_unary_relation_feature_extraction():
    """Test extracting unary candidates from mentions from documents.

    Checks that the default feature library is exactly the union of the
    four individual feature families (textual, tabular, structural, visual).
    """
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    # Parsing
    doc = parse_doc(docs_path, "112823", pdf_path)
    assert len(doc.sentences) == 799

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)
    Part = mention_subclass("Part")
    mention_extractor_udf = MentionExtractorUDF([Part], [part_ngrams], [part_matcher])
    doc = mention_extractor_udf.apply(doc)

    assert doc.name == "112823"
    assert len(doc.parts) == 62
    part = doc.parts[0]
    logger.info(f"Part: {part.context}")

    # Candidate Extraction
    PartRel = candidate_subclass("PartRel", [Part])
    candidate_extractor_udf = CandidateExtractorUDF([PartRel], None, False, False, True)
    doc = candidate_extractor_udf.apply(doc, split=0)

    def _n_feature_keys(feature_extractor):
        """Featurize ``doc`` with the given extractor and count distinct keys."""
        featurizer_udf = FeaturizerUDF([PartRel], feature_extractor)
        features = itertools.chain.from_iterable(featurizer_udf.apply(doc))
        return len({key for feature in features for key in feature["keys"]})

    # Featurization based on the default feature library
    n_default_feats = _n_feature_keys(FeatureExtractor())

    # Featurization with each individual feature family on its own
    n_textual_features = _n_feature_keys(FeatureExtractor(features=["textual"]))
    n_tabular_features = _n_feature_keys(FeatureExtractor(features=["tabular"]))
    n_structural_features = _n_feature_keys(FeatureExtractor(features=["structural"]))
    n_visual_features = _n_feature_keys(FeatureExtractor(features=["visual"]))

    # The default library is the disjoint union of the four families.
    assert (
        n_default_feats
        == n_textual_features
        + n_tabular_features
        + n_structural_features
        + n_visual_features
    )
def test_multinary_relation_feature_extraction():
    """Test extracting candidates from mentions from documents.

    Verifies that the default feature library equals the union of the four
    feature families, that custom feature extractors add their keys on top,
    and that a failing custom extractor propagates its exception.
    """
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    # Parsing
    doc = parse_doc(docs_path, "112823", pdf_path)
    assert len(doc.sentences) == 799

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)
    temp_ngrams = MentionNgrams(n_max=1)
    volt_ngrams = MentionNgrams(n_max=1)
    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    mention_extractor_udf = MentionExtractorUDF(
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )
    doc = mention_extractor_udf.apply(doc)
    assert len(doc.parts) == 62
    assert len(doc.temps) == 16
    assert len(doc.volts) == 33
    part = doc.parts[0]
    temp = doc.temps[0]
    volt = doc.volts[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")
    logger.info(f"Volt: {volt.context}")

    # Candidate Extraction
    PartTempVolt = candidate_subclass("PartTempVolt", [Part, Temp, Volt])
    candidate_extractor_udf = CandidateExtractorUDF([PartTempVolt], None, False, False, True)
    doc = candidate_extractor_udf.apply(doc, split=0)

    # Manually set id as it is not set automatically b/c a database is not used.
    for i, cand in enumerate(doc.part_temp_volts):
        cand.id = i
    n_cands = len(doc.part_temp_volts)

    def _n_feature_keys(feature_extractor):
        """Featurize ``doc`` with the given extractor and count distinct keys."""
        featurizer_udf = FeaturizerUDF([PartTempVolt], feature_extractor)
        features = itertools.chain.from_iterable(featurizer_udf.apply(doc))
        return len({key for feature in features for key in feature["keys"]})

    # Featurization based on the default feature library
    n_default_feats = _n_feature_keys(FeatureExtractor())

    # Example feature extractor: emits one extra key per candidate
    def feat_ext(candidates):
        candidates = candidates if isinstance(candidates, list) else [candidates]
        for candidate in candidates:
            yield candidate.id, f"cand_id_{candidate.id}", 1

    # Default feature library plus one custom feature extractor
    n_default_w_customized_features = _n_feature_keys(
        FeatureExtractor(customize_feature_funcs=[feat_ext])
    )

    # Example spurious feature extractor: its error must surface to the caller
    def bad_feat_ext(candidates):
        raise RuntimeError()

    logger.info("Featurizing with a spurious feature extractor...")
    featurizer_udf = FeaturizerUDF(
        [PartTempVolt],
        feature_extractors=FeatureExtractor(customize_feature_funcs=[bad_feat_ext]),
    )
    with pytest.raises(RuntimeError):
        featurizer_udf.apply(doc)

    # Featurization with each individual feature family on its own
    n_textual_features = _n_feature_keys(FeatureExtractor(features=["textual"]))
    n_tabular_features = _n_feature_keys(FeatureExtractor(features=["tabular"]))
    n_structural_features = _n_feature_keys(FeatureExtractor(features=["structural"]))
    n_visual_features = _n_feature_keys(FeatureExtractor(features=["visual"]))

    # Default library is the union of the four families; the custom extractor
    # contributes exactly one additional key per candidate.
    assert (
        n_default_feats
        == n_textual_features
        + n_tabular_features
        + n_structural_features
        + n_visual_features
    )
    assert n_default_w_customized_features == n_default_feats + n_cands
def test_unary_relation_feature_extraction():
    """Test extracting unary candidates from mentions from documents.

    Session-backed variant: featurizes with a database-bound Featurizer and
    checks that the default feature library's key count is the sum of the
    four individual feature families' key counts.
    """
    PARALLEL = 1
    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)
    Part = mention_subclass("Part")
    mention_extractor = MentionExtractor(session, [Part], [part_ngrams], [part_matcher])
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert docs[0].name == "112823"
    assert session.query(Part).count() == 58
    part = session.query(Part).order_by(Part.id).all()[0]
    logger.info(f"Part: {part.context}")

    # Candidate Extraction
    PartRel = candidate_subclass("PartRel", [Part])
    candidate_extractor = CandidateExtractor(session, [PartRel])
    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    def _n_feature_keys(featurizer):
        """Apply the featurizer on the train split, return the FeatureKey
        count, and clear the keys so the next run starts clean."""
        featurizer.apply(split=0, train=True, parallelism=PARALLEL)
        n = session.query(FeatureKey).count()
        featurizer.clear(train=True)
        return n

    # Featurization based on the default feature library
    n_default_feats = _n_feature_keys(Featurizer(session, [PartRel]))

    # Featurization with each individual feature family on its own
    n_textual_features = _n_feature_keys(
        Featurizer(
            session, [PartRel], feature_extractors=FeatureExtractor(features=["textual"])
        )
    )
    n_tabular_features = _n_feature_keys(
        Featurizer(
            session, [PartRel], feature_extractors=FeatureExtractor(features=["tabular"])
        )
    )
    n_structural_features = _n_feature_keys(
        Featurizer(
            session, [PartRel], feature_extractors=FeatureExtractor(features=["structural"])
        )
    )
    n_visual_features = _n_feature_keys(
        Featurizer(
            session, [PartRel], feature_extractors=FeatureExtractor(features=["visual"])
        )
    )

    # The default library is the union of the four families.
    assert (
        n_default_feats
        == n_textual_features
        + n_tabular_features
        + n_structural_features
        + n_visual_features
    )