Example #1
def test_ngrams():
    """Test ngram limits in mention extraction"""
    file_name = "lincoln_short"
    docs_path = f"tests/data/pure_html/{file_name}.html"
    doc = parse_doc(docs_path, file_name)

    # Mention Extraction
    Person = mention_subclass("Person")
    person_ngrams = MentionNgrams(n_max=3)
    person_matcher = PersonMatcher()

    mention_extractor_udf = MentionExtractorUDF(
        [Person], [person_ngrams], [person_matcher]
    )
    doc = mention_extractor_udf.apply(doc)

    assert len(doc.persons) == 118
    mentions = doc.persons
    assert len([x for x in mentions if x.context.get_num_words() == 1]) == 49
    assert len([x for x in mentions if x.context.get_num_words() > 3]) == 0

    # Clear existing mentions before testing unigram exclusion
    for mention in doc.persons[:]:
        doc.persons.remove(mention)
    assert len(doc.persons) == 0

    person_ngrams = MentionNgrams(n_min=2, n_max=3)
    mention_extractor_udf = MentionExtractorUDF(
        [Person], [person_ngrams], [person_matcher]
    )
    doc = mention_extractor_udf.apply(doc)
    assert len(doc.persons) == 69
    mentions = doc.persons
    assert len([x for x in mentions if x.context.get_num_words() == 1]) == 0
    assert len([x for x in mentions if x.context.get_num_words() > 3]) == 0
Example #2
 def __init__(
     self,
     n_min: int = 1,
     n_max: int = 5,
     split_tokens: Collection[str] = [],
     types: Optional[Collection[str]] = None,
 ) -> None:
     """Initialize MentionNgrams."""
     MentionNgrams.__init__(self, n_min=n_min, n_max=n_max, split_tokens=split_tokens)
     if types is not None:
         self.types = [t.strip().lower() for t in types]
     else:
         self.types = None
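
The types argument, when given, is normalized to lowercased, stripped strings. A quick standalone check of that normalization (the surrounding class is not shown here):

assert [t.strip().lower() for t in [" Figure", "TABLE "]] == ["figure", "table"]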
Example #3
    def __init__(self,
                 parts_by_doc=None,
                 n_max=3,
                 expand=True,
                 split_tokens=None):
        """MentionNgrams specifically for transistor parts.

        :param parts_by_doc: a dictionary d where d[document_name.upper()] =
            [partA, partB, ...]
        """
        MentionNgrams.__init__(self, n_max=n_max, split_tokens=split_tokens)
        self.parts_by_doc = parts_by_doc
        self.expander = expand_part_range if expand else (lambda x: [x])
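
A minimal construction sketch for this subclass, assuming it is the MentionNgramsPart used in Examples #14 and #18, with parts_by_doc shaped as the docstring describes (keys are uppercased document names; the part numbers are illustrative):

parts_by_doc = {
    "BC546": ["BC546A", "BC546B", "BC547A"],
}
part_ngrams = MentionNgramsPart(parts_by_doc=parts_by_doc, n_max=3)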
Example #4
def test_visualizer():
    """Unit test of visualizer using the md document."""
    from fonduer.utils.visualizer import Visualizer  # noqa
    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/md.pdf"

    # Grab the md document
    doc = parse_doc(docs_path, "md", pdf_path)
    assert doc.name == "md"

    organization_ngrams = MentionNgrams(n_max=1)

    Org = mention_subclass("Org")

    organization_matcher = OrganizationMatcher()

    mention_extractor_udf = MentionExtractorUDF([Org], [organization_ngrams],
                                                [organization_matcher])

    doc = mention_extractor_udf.apply(doc)

    Organization = candidate_subclass("Organization", [Org])

    candidate_extractor_udf = CandidateExtractorUDF([Organization], None,
                                                    False, False, True)

    doc = candidate_extractor_udf.apply(doc, split=0)

    cands = doc.organizations

    # Test visualizer
    pdf_path = "tests/data/pdf_simple"
    vis = Visualizer(pdf_path)
    vis.display_candidates([cands[0]])
Example #5
 def apply(self, doc):
     for ts in MentionNgrams.apply(self, doc):
         if ts.get_span().endswith(".0"):
             value = ts.get_span()[:-2]
             yield TemporaryImplicitSpanMention(
                 sentence=ts.sentence,
                 char_start=ts.char_start,
                 char_end=ts.char_end,
                 expander_key="volt_expander",
                 position=0,
                 text=value,
                 words=[value],
                 lemmas=[value],
                 pos_tags=[ts.get_attrib_tokens("pos_tags")[-1]],
                 ner_tags=[ts.get_attrib_tokens("ner_tags")[-1]],
                 dep_parents=[ts.get_attrib_tokens("dep_parents")[-1]],
                 dep_labels=[ts.get_attrib_tokens("dep_labels")[-1]],
                 page=[ts.get_attrib_tokens("page")[-1]]
                 if ts.sentence.is_visual() else [None],
                 top=[ts.get_attrib_tokens("top")[-1]]
                 if ts.sentence.is_visual() else [None],
                 left=[ts.get_attrib_tokens("left")[-1]]
                 if ts.sentence.is_visual() else [None],
                 bottom=[ts.get_attrib_tokens("bottom")[-1]]
                 if ts.sentence.is_visual() else [None],
                 right=[ts.get_attrib_tokens("right")[-1]]
                 if ts.sentence.is_visual() else [None],
                 meta=None,
             )
         else:
             yield ts
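
The only normalization here is stripping a trailing ".0" from numeric spans; everything else passes through unchanged. A standalone check of that rule:

for raw, expected in [("5.0", "5"), ("12.0", "12"), ("3.3", "3.3")]:
    value = raw[:-2] if raw.endswith(".0") else raw
    assert value == expected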
Example #6
def test_ngrams(caplog):
    """Test ngram limits in mention extraction"""
    caplog.set_level(logging.INFO)

    PARALLEL = 4

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/pure_html/lincoln_short.html"

    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 503
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    Person = mention_subclass("Person")
    person_ngrams = MentionNgrams(n_max=3)
    person_matcher = PersonMatcher()

    mention_extractor = MentionExtractor(
        session, [Person], [person_ngrams], [person_matcher]
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Person).count() == 118
    mentions = session.query(Person).all()
    assert len([x for x in mentions if x.context.get_num_words() == 1]) == 49
    assert len([x for x in mentions if x.context.get_num_words() > 3]) == 0

    # Test for unigram exclusion
    person_ngrams = MentionNgrams(n_min=2, n_max=3)
    mention_extractor = MentionExtractor(
        session, [Person], [person_ngrams], [person_matcher]
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)
    assert session.query(Person).count() == 69
    mentions = session.query(Person).all()
    assert len([x for x in mentions if x.context.get_num_words() == 1]) == 0
    assert len([x for x in mentions if x.context.get_num_words() > 3]) == 0
Example #7
 def apply(self, session, context):
     for ts in MentionNgrams.apply(self, session, context):
         m = re.match(
             r"^([\+\-\u2010\u2011\u2012\u2013\u2014\u2212\uf02d])?(\s*)(\d+)$",
             ts.get_span(),
             re.U,
         )
         if m:
             if m.group(1) is None:
                 temp = ""
             elif m.group(1) == "+":
                 if m.group(2) != "":
                     # If bigram '+ 150' is seen, accept the unigram '150',
                     # not both
                     continue
                 temp = ""
             else:  # m.group(1) is a type of negative sign
                 # A bigram '- 150' is different from unigram '150', so we
                 # keep the implicit '-150'
                 temp = "-"
             temp += m.group(3)
             yield TemporaryImplicitSpan(
                 sentence=ts.sentence,
                 char_start=ts.char_start,
                 char_end=ts.char_end,
                 expander_key=u"temp_expander",
                 position=0,
                 text=temp,
                 words=[temp],
                 lemmas=[temp],
                 pos_tags=[ts.get_attrib_tokens("pos_tags")[-1]],
                 ner_tags=[ts.get_attrib_tokens("ner_tags")[-1]],
                 dep_parents=[ts.get_attrib_tokens("dep_parents")[-1]],
                 dep_labels=[ts.get_attrib_tokens("dep_labels")[-1]],
                 page=[ts.get_attrib_tokens("page")[-1]]
                 if ts.sentence.is_visual()
                 else [None],
                 top=[ts.get_attrib_tokens("top")[-1]]
                 if ts.sentence.is_visual()
                 else [None],
                 left=[ts.get_attrib_tokens("left")[-1]]
                 if ts.sentence.is_visual()
                 else [None],
                 bottom=[ts.get_attrib_tokens("bottom")[-1]]
                 if ts.sentence.is_visual()
                 else [None],
                 right=[ts.get_attrib_tokens("right")[-1]]
                 if ts.sentence.is_visual()
                 else [None],
                 meta=None,
             )
         else:
             yield ts
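
The sign-handling branch above is easy to check in isolation. A standalone sketch of the same normalization (skip '+ 150' bigrams, drop an explicit '+', collapse any negative-sign character to '-'):

import re

SIGN_NUM = re.compile(
    r"^([\+\-\u2010\u2011\u2012\u2013\u2014\u2212\uf02d])?(\s*)(\d+)$", re.U
)

def normalize(span):
    m = SIGN_NUM.match(span)
    if not m:
        return span                    # non-numeric spans pass through unchanged
    if m.group(1) == "+" and m.group(2):
        return None                    # '+ 150' bigram: skip; the unigram '150' is kept
    sign = "" if m.group(1) in (None, "+") else "-"
    return sign + m.group(3)

assert normalize("150") == "150"
assert normalize("+150") == "150"
assert normalize("+ 150") is None
assert normalize("\u2212 150") == "-150"   # unicode minus becomes an implicit '-150'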
Example #8
def test_visualizer(caplog):
    """Unit test of visualizer using the md document."""
    from fonduer.utils.visualizer import Visualizer  # noqa
    caplog.set_level(logging.INFO)
    session = Meta.init("postgresql://localhost:5432/" + ATTRIBUTE).Session()

    PARALLEL = 1
    max_docs = 1
    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/md.pdf"

    # Preprocessor for the Docs
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    # Create a Parser and parse the md document
    corpus_parser = Parser(session,
                           structural=True,
                           lingual=True,
                           visual=True,
                           pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    # Grab the md document
    doc = session.query(Document).order_by(Document.name).all()[0]
    assert doc.name == "md"

    organization_ngrams = MentionNgrams(n_max=1)

    Org = mention_subclass("Org")

    organization_matcher = OrganizationMatcher()

    mention_extractor = MentionExtractor(session, [Org], [organization_ngrams],
                                         [organization_matcher])

    mention_extractor.apply([doc], parallelism=PARALLEL)

    Organization = candidate_subclass("Organization", [Org])

    candidate_extractor = CandidateExtractor(session, [Organization])

    candidate_extractor.apply([doc], split=0, parallelism=PARALLEL)

    cands = session.query(Organization).filter(Organization.split == 0).all()

    # Test visualizer
    pdf_path = "tests/data/pdf_simple"
    vis = Visualizer(pdf_path)
    vis.display_candidates([cands[0]])
Example #9
def mention_setup():
    """Set up mentions."""
    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/"

    # Preprocessor for the Docs
    preprocessor = HTMLDocPreprocessor(docs_path)
    doc = next(iter(preprocessor))

    # Create a Parser and parse the md document
    parser_udf = get_parser_udf(
        structural=True,
        tabular=True,
        lingual=True,
        visual=True,
        visual_parser=PdfVisualParser(pdf_path),
        language="en",
    )
    doc = parser_udf.apply(doc)

    # Create 1-gram span mentions
    space = MentionNgrams(n_min=1, n_max=1)
    mentions = [tc for tc in space.apply(doc)]
    return mentions
Example #10
 def apply(self, session, context):
     for ts in MentionNgrams.apply(self, session, context):
         enumerated_parts = [
             part.upper() for part in expand_part_range(ts.get_span())
         ]
         parts = set(enumerated_parts)
         if self.parts_by_doc:
             possible_parts = self.parts_by_doc[ts.parent.document.name.upper()]
             for base_part in enumerated_parts:
                 for part in possible_parts:
                     if part.startswith(base_part) and len(base_part) >= 4:
                         parts.add(part)
         for i, part in enumerate(parts):
             if " " in part:
                 continue  # it won't pass the part_matcher
             if part == ts.get_span():
                 yield ts
             else:
                 yield TemporaryImplicitSpan(
                     sentence=ts.sentence,
                     char_start=ts.char_start,
                     char_end=ts.char_end,
                     expander_key=u"part_expander",
                     position=i,
                     text=part,
                     words=[part],
                     lemmas=[part],
                     pos_tags=[ts.get_attrib_tokens("pos_tags")[0]],
                     ner_tags=[ts.get_attrib_tokens("ner_tags")[0]],
                     dep_parents=[ts.get_attrib_tokens("dep_parents")[0]],
                     dep_labels=[ts.get_attrib_tokens("dep_labels")[0]],
                     page=[min(ts.get_attrib_tokens("page"))]
                     if ts.sentence.is_visual()
                     else [None],
                     top=[min(ts.get_attrib_tokens("top"))]
                     if ts.sentence.is_visual()
                     else [None],
                     left=[max(ts.get_attrib_tokens("left"))]
                     if ts.sentence.is_visual()
                     else [None],
                     bottom=[min(ts.get_attrib_tokens("bottom"))]
                     if ts.sentence.is_visual()
                     else [None],
                     right=[max(ts.get_attrib_tokens("right"))]
                     if ts.sentence.is_visual()
                     else [None],
                     meta=None,
                 )
Example #11
def test_visualizer():
    """Unit test of visualizer using the md document."""
    from fonduer.utils.visualizer import Visualizer, get_box  # noqa

    docs_path = "tests/data/html_simple/md.html"
    pdf_path = "tests/data/pdf_simple/"

    # Grab the md document
    doc = parse_doc(docs_path, "md", pdf_path)
    assert doc.name == "md"

    organization_ngrams = MentionNgrams(n_max=1)

    Org = mention_subclass("Org")

    organization_matcher = OrganizationMatcher()

    mention_extractor_udf = MentionExtractorUDF([Org], [organization_ngrams],
                                                [organization_matcher])

    doc = mention_extractor_udf.apply(doc)

    Organization = candidate_subclass("Organization", [Org])

    candidate_extractor_udf = CandidateExtractorUDF([Organization], None,
                                                    False, False, True)

    doc = candidate_extractor_udf.apply(doc, split=0)

    # Take one candidate
    cand = doc.organizations[0]

    pdf_path = "tests/data/pdf_simple"
    vis = Visualizer(pdf_path)

    # Test bounding boxes
    boxes = [get_box(mention.context) for mention in cand.get_mentions()]
    for box in boxes:
        assert box.top <= box.bottom
        assert box.left <= box.right
    assert boxes == [
        mention.context.get_bbox() for mention in cand.get_mentions()
    ]

    # Test visualizer
    vis.display_candidates([cand])
Example #12
    def apply(self, doc):
        for ts in MentionNgrams.apply(self, doc):
            m = re.match(r"^(±)?\s*(\d+)\s*(\.)?\s*(\d*)$", ts.get_span())
            if m:
                # Handle the case where stray spaces are inserted (e.g. "± 2  . 3")
                temp = ""
                if m.group(1):
                    temp += m.group(1)
                if m.group(2):
                    temp += m.group(2)
                if m.group(3):
                    temp += m.group(3)
                if m.group(4):
                    temp += m.group(4)

                yield TemporaryImplicitSpanMention(
                    sentence=ts.sentence,
                    char_start=ts.char_start,
                    char_end=ts.char_end,
                    expander_key="opamp_exp",
                    position=0,
                    text=temp,
                    words=[temp],
                    lemmas=[temp],
                    pos_tags=[ts.get_attrib_tokens("pos_tags")[-1]],
                    ner_tags=[ts.get_attrib_tokens("ner_tags")[-1]],
                    dep_parents=[ts.get_attrib_tokens("dep_parents")[-1]],
                    dep_labels=[ts.get_attrib_tokens("dep_labels")[-1]],
                    page=[ts.get_attrib_tokens("page")[-1]]
                    if ts.sentence.is_visual() else [None],
                    top=[ts.get_attrib_tokens("top")[-1]]
                    if ts.sentence.is_visual() else [None],
                    left=[ts.get_attrib_tokens("left")[-1]]
                    if ts.sentence.is_visual() else [None],
                    bottom=[ts.get_attrib_tokens("bottom")[-1]]
                    if ts.sentence.is_visual() else [None],
                    right=[ts.get_attrib_tokens("right")[-1]]
                    if ts.sentence.is_visual() else [None],
                    meta=None,
                )
            else:
                yield ts
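
The space-collapsing above amounts to concatenating the regex groups and skipping the empty ones. A standalone check of the same rule:

import re

def squeeze(span):
    m = re.match(r"^(±)?\s*(\d+)\s*(\.)?\s*(\d*)$", span)
    if not m:
        return span
    # Concatenate the sign, integer, dot, and fraction groups, skipping empties
    return "".join(g for g in m.groups() if g)

assert squeeze("± 2  . 3") == "±2.3"
assert squeeze("± 5") == "±5"
assert squeeze("2.3") == "2.3"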
Example #13
    def apply(self, doc: Document) -> Iterator[TemporaryContext]:
        """Generate MentionNgrams from a Document by parsing all of its Sentences.

        :param doc: The ``Document`` to parse.
        :type doc: ``Document``
        :raises TypeError: If the input doc is not of type ``Document``.
        """
        if not isinstance(doc, Document):
            raise TypeError(
                "Input Contexts to MentionNgrams.apply() must be of type Document"
            )

        for ts in MentionNgrams.apply(self, doc):
            yield ts
            
        for ts in MentionDocuments.apply(self, doc):
            yield ts
            
        for ts in MentionFigures.apply(self, doc):
            yield ts
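
The three sequential loops simply concatenate three generators. An equivalent, behavior-preserving formulation of the same method body, reusing the names from the example above, chains them with itertools:

from itertools import chain

def apply(self, doc: Document) -> Iterator[TemporaryContext]:
    if not isinstance(doc, Document):
        raise TypeError(
            "Input Contexts to MentionNgrams.apply() must be of type Document"
        )
    # Chain the three mention spaces into one generator
    yield from chain(
        MentionNgrams.apply(self, doc),
        MentionDocuments.apply(self, doc),
        MentionFigures.apply(self, doc),
    )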
Example #14
def main(
    conn_string,
    stg_temp_min=False,
    stg_temp_max=False,
    polarity=False,
    ce_v_max=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    rel_list = []
    if stg_temp_min:
        rel_list.append("stg_temp_min")

    if stg_temp_max:
        rel_list.append("stg_temp_max")

    if polarity:
        rel_list.append("polarity")

    if ce_v_max:
        rel_list.append("ce_v_max")

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info(f"Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(session,
                                                          dirname,
                                                          first_time=parse,
                                                          parallel=parallel,
                                                          max_docs=max_docs)
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    Part = mention_subclass("Part")
    part_matcher = get_matcher("part")
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)

    mentions.append(Part)
    ngrams.append(part_ngrams)
    matchers.append(part_matcher)

    if stg_temp_min:
        StgTempMin = mention_subclass("StgTempMin")
        stg_temp_min_matcher = get_matcher("stg_temp_min")
        stg_temp_min_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMin)
        ngrams.append(stg_temp_min_ngrams)
        matchers.append(stg_temp_min_matcher)

    if stg_temp_max:
        StgTempMax = mention_subclass("StgTempMax")
        stg_temp_max_matcher = get_matcher("stg_temp_max")
        stg_temp_max_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMax)
        ngrams.append(stg_temp_max_ngrams)
        matchers.append(stg_temp_max_matcher)

    if polarity:
        Polarity = mention_subclass("Polarity")
        polarity_matcher = get_matcher("polarity")
        polarity_ngrams = MentionNgrams(n_max=1)

        mentions.append(Polarity)
        ngrams.append(polarity_ngrams)
        matchers.append(polarity_matcher)

    if ce_v_max:
        CeVMax = mention_subclass("CeVMax")
        ce_v_max_matcher = get_matcher("ce_v_max")
        ce_v_max_ngrams = MentionNgramsVolt(n_max=1)

        mentions.append(CeVMax)
        ngrams.append(ce_v_max_ngrams)
        matchers.append(ce_v_max_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")
    logger.info(f"Total Part: {session.query(Part).count()}")
    if stg_temp_min:
        logger.info(f"Total StgTempMin: {session.query(StgTempMin).count()}")
    if stg_temp_max:
        logger.info(f"Total StgTempMax: {session.query(StgTempMax).count()}")
    if polarity:
        logger.info(f"Total Polarity: {session.query(Polarity).count()}")
    if ce_v_max:
        logger.info(f"Total CeVMax: {session.query(CeVMax).count()}")

    # Candidate Extraction
    cands = []
    throttlers = []
    if stg_temp_min:
        PartStgTempMin = candidate_subclass("PartStgTempMin",
                                            [Part, StgTempMin])
        stg_temp_min_throttler = stg_temp_filter

        cands.append(PartStgTempMin)
        throttlers.append(stg_temp_min_throttler)

    if stg_temp_max:
        PartStgTempMax = candidate_subclass("PartStgTempMax",
                                            [Part, StgTempMax])
        stg_temp_max_throttler = stg_temp_filter

        cands.append(PartStgTempMax)
        throttlers.append(stg_temp_max_throttler)

    if polarity:
        PartPolarity = candidate_subclass("PartPolarity", [Part, Polarity])
        polarity_throttler = polarity_filter

        cands.append(PartPolarity)
        throttlers.append(polarity_throttler)

    if ce_v_max:
        PartCeVMax = candidate_subclass("PartCeVMax", [Part, CeVMax])
        ce_v_max_throttler = ce_v_max_filter

        cands.append(PartCeVMax)
        throttlers.append(ce_v_max_throttler)

    candidate_extractor = CandidateExtractor(session,
                                             cands,
                                             throttlers=throttlers)

    if first_time:
        for i, docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(docs, split=i, parallelism=parallel)
            num_cands = session.query(Candidate).filter(
                Candidate.split == i).count()
            logger.info(f"Candidates in split={i}: {num_cands}")

    # These must be sorted for deterministic behavior.
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)

    end = timer()
    logger.warning(
        f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"Total train candidate: {sum(len(_) for _ in train_cands)}")
    logger.info(f"Total dev candidate: {sum(len(_) for _ in dev_cands)}")
    logger.info(f"Total test candidate: {sum(len(_) for _ in test_cands)}")

    pickle_file = os.path.join(dirname, "data/parts_by_doc_new.pkl")
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    # Check total recall
    for i, name in enumerate(rel_list):
        logger.info(name)
        result = entity_level_scores(
            candidates_to_entities(dev_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=dev_docs,
        )
        logger.info(f"{name} Total Dev Recall: {result.rec:.3f}")
        result = entity_level_scores(
            candidates_to_entities(test_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=test_docs,
        )
        logger.info(f"{name} Total Test Recall: {result.rec:.3f}")

    # Featurization
    start = timer()
    cands = []
    if stg_temp_min:
        cands.append(PartStgTempMin)

    if stg_temp_max:
        cands.append(PartStgTempMax)

    if polarity:
        cands.append(PartPolarity)

    if ce_v_max:
        cands.append(PartCeVMax)

    # Using parallelism = 1 for deterministic behavior.
    featurizer = Featurizer(session, cands, parallelism=1)
    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True)
        featurizer.apply(split=1)
        featurizer.apply(split=2)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        F_train_dict = {}
        F_dev_dict = {}
        F_test_dict = {}
        for idx, relation in enumerate(rel_list):
            F_train_dict[relation] = F_train[idx]
            F_dev_dict[relation] = F_dev[idx]
            F_test_dict[relation] = F_test[idx]

        pickle.dump(F_train_dict,
                    open(os.path.join(dirname, "F_train_dict.pkl"), "wb"))
        pickle.dump(F_dev_dict,
                    open(os.path.join(dirname, "F_dev_dict.pkl"), "wb"))
        pickle.dump(F_test_dict,
                    open(os.path.join(dirname, "F_test_dict.pkl"), "wb"))
    else:
        F_train_dict = pickle.load(
            open(os.path.join(dirname, "F_train_dict.pkl"), "rb"))
        F_dev_dict = pickle.load(
            open(os.path.join(dirname, "F_dev_dict.pkl"), "rb"))
        F_test_dict = pickle.load(
            open(os.path.join(dirname, "F_test_dict.pkl"), "rb"))

        F_train = []
        F_dev = []
        F_test = []
        for relation in rel_list:
            F_train.append(F_train_dict[relation])
            F_dev.append(F_dev_dict[relation])
            F_test.append(F_test_dict[relation])

    logger.info("Done.")

    for i, cand in enumerate(cands):
        logger.info(f"{cand} Train shape: {F_train[i].shape}")
        logger.info(f"{cand} Test shape: {F_test[i].shape}")
        logger.info(f"{cand} Dev shape: {F_dev[i].shape}")

    logger.info("Labeling training data...")

    # Labeling
    start = timer()
    lfs = []
    if stg_temp_min:
        lfs.append(stg_temp_min_lfs)

    if stg_temp_max:
        lfs.append(stg_temp_max_lfs)

    if polarity:
        lfs.append(polarity_lfs)

    if ce_v_max:
        lfs.append(ce_v_max_lfs)

    # Using parallelism = 1 for deterministic behavior.
    labeler = Labeler(session, cands, parallelism=1)

    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True)
        logger.info("Done...")

        # Uncomment if debugging LFs
        #  load_transistor_labels(session, cands, ["ce_v_max"])
        #  labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        #  labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    elif re_label:
        logger.info("Updating LFs...")
        labeler.update(split=0, lfs=lfs)
        logger.info("Done...")

        # Uncomment if debugging LFs
        #  labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        #  labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    logger.info("Getting label matrices...")

    L_train = labeler.get_label_matrices(train_cands)

    # Uncomment if debugging LFs
    #  L_dev = labeler.get_label_matrices(dev_cands)
    #  L_dev_gold = labeler.get_gold_labels(dev_cands, annotator="gold")
    #
    #  L_test = labeler.get_label_matrices(test_cands)
    #  L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")

    logger.info("Done.")

    if first_time:
        marginals_dict = {}
        for idx, relation in enumerate(rel_list):
            marginals_dict[relation] = generative_model(L_train[idx])

        pickle.dump(marginals_dict,
                    open(os.path.join(dirname, "marginals_dict.pkl"), "wb"))
    else:
        marginals_dict = pickle.load(
            open(os.path.join(dirname, "marginals_dict.pkl"), "rb"))

    marginals = []
    for relation in rel_list:
        marginals.append(marginals_dict[relation])

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    start = timer()

    word_counter = collect_word_counter(train_cands)

    # Training config
    config = {
        "meta_config": {
            "verbose": True,
            "seed": 17
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.0
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    "model/all/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }

    emmental.init(log_dir=Meta.log_path, config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)
    train_idxs = []
    train_dataloader = []
    for idx, relation in enumerate(rel_list):
        diffs = marginals[idx].max(axis=1) - marginals[idx].min(axis=1)
        train_idxs.append(np.where(diffs > 1e-6)[0])

        train_dataloader.append(
            EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(
                    relation,
                    train_cands[idx],
                    F_train[idx],
                    emb_layer.word2id,
                    marginals[idx],
                    train_idxs[idx],
                ),
                split="train",
                batch_size=100,
                shuffle=True,
            ))

    num_feature_keys = len(featurizer.get_keys())

    model = EmmentalModel(name="transistor_tasks")

    # Create one task per relation: names, arities, feature count, class counts
    tasks = create_task(
        rel_list,
        [2] * len(rel_list),
        num_feature_keys,
        [2] * len(rel_list),
        emb_layer,
        model="LogisticRegression",
    )

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()

    # Given a list of dataloaders, trains on multiple tasks
    emmental_learner.learn(model, train_dataloader)

    # Build a test dataloader for each relation
    for idx, relation in enumerate(rel_list):
        test_dataloader = EmmentalDataLoader(
            task_to_label_dict={relation: "labels"},
            dataset=FonduerDataset(relation, test_cands[idx], F_test[idx],
                                   emb_layer.word2id, 2),
            split="test",
            batch_size=100,
            shuffle=False,
        )

        test_preds = model.predict(test_dataloader, return_preds=True)

        best_result, best_b = scoring(
            relation,
            test_preds,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

        # Dump CSV files for CE_V_MAX for digi-key analysis
        if relation == "ce_v_max":
            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=100,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            dump_candidates(test_cands[idx], Y_prob, "ce_v_max_test_probs.csv")
            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            dump_candidates(dev_cands[idx], Y_prob, "ce_v_max_dev_probs.csv")

        # Dump CSV files for POLARITY for digi-key analysis
        if relation == "polarity":
            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=100,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            dump_candidates(test_cands[idx], Y_prob, "polarity_test_probs.csv")
            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            dump_candidates(dev_cands[idx], Y_prob, "polarity_dev_probs.csv")

    end = timer()
    logger.warning(f"Classification Time (min): {((end - start) / 60.0):.1f}")
Example #15
def test_feature_extraction():
    """Test extracting candidates from mentions from documents."""
    PARALLEL = 1

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session,
                           structural=True,
                           lingual=True,
                           visual=True,
                           pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)
    temp_ngrams = MentionNgrams(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(session, [Part, Temp],
                                         [part_ngrams, temp_ngrams],
                                         [part_matcher, temp_matcher])
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert docs[0].name == "112823"
    assert session.query(Part).count() == 58
    assert session.query(Temp).count() == 16
    part = session.query(Part).order_by(Part.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(session, [PartTemp])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    n_cands = session.query(PartTemp).count()

    # Featurization based on default feature library
    featurizer = Featurizer(session, [PartTemp])

    # Test featurization with the default feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_default_feats = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    # Example feature extractor
    def feat_ext(candidates):
        candidates = candidates if isinstance(candidates, list) else [candidates]
        for candidate in candidates:
            yield candidate.id, f"cand_id_{candidate.id}", 1

    # Featurization with one extra feature extractor
    feature_extractors = FeatureExtractor(customize_feature_funcs=[feat_ext])
    featurizer = Featurizer(session, [PartTemp],
                            feature_extractors=feature_extractors)

    # Test featurization with the default feature library plus one extra feature extractor
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_default_w_customized_features = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    # Featurization with only textual feature
    feature_extractors = FeatureExtractor(features=["textual"])
    featurizer = Featurizer(session, [PartTemp],
                            feature_extractors=feature_extractors)

    # Test featurization with only the textual feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_textual_features = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    # Featurization with only tabular feature
    feature_extractors = FeatureExtractor(features=["tabular"])
    featurizer = Featurizer(session, [PartTemp],
                            feature_extractors=feature_extractors)

    # Test featurization with only the tabular feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_tabular_features = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    # Featurization with only structural feature
    feature_extractors = FeatureExtractor(features=["structural"])
    featurizer = Featurizer(session, [PartTemp],
                            feature_extractors=feature_extractors)

    # Test featurization with only the structural feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_structural_features = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    # Featurization with only visual feature
    feature_extractors = FeatureExtractor(features=["visual"])
    featurizer = Featurizer(session, [PartTemp],
                            feature_extractors=feature_extractors)

    # Test featurization with only the visual feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_visual_features = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    assert (n_default_feats == n_textual_features + n_tabular_features +
            n_structural_features + n_visual_features)

    assert n_default_w_customized_features == n_default_feats + n_cands
Example #16
print("Sentences:", session.query(Sentence).count())
print("==============================")

# Getting all documents parsed by Fonduer
print("Getting documents and sentences...")
docs = session.query(Document).all()
#sents = session.query(Sentence).all()

from fonduer.candidates import CandidateExtractor, MentionExtractor, MentionNgrams
from fonduer.candidates.models import mention_subclass, candidate_subclass
from fonduer.candidates.matchers import RegexMatchSpan, Union, LambdaFunctionMatcher
from dataset_utils import price_match

# Defining ngrams for candidates
extraction_name = 'price'
ngrams = MentionNgrams(n_max=5)

# Define matchers
matchers = LambdaFunctionMatcher(func=price_match)

# Getting candidates
PriceMention = mention_subclass("PriceMention")
mention_extractor = MentionExtractor(
    session, [PriceMention], [ngrams], [matchers]
)
mention_extractor.clear_all()
mention_extractor.apply(docs, parallelism=parallelism)
candidate_class = candidate_subclass("Price", [PriceMention])
candidate_extractor = CandidateExtractor(session, [candidate_class])

# Applying candidate extractors
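The snippet is cut off after this comment; mirroring Example #15, the next step would be candidate_extractor.apply(docs, split=0, parallelism=parallelism). Separately, price_match is imported from dataset_utils and is not shown. A hypothetical stand-in (name and regex are illustrative only) shows the contract LambdaFunctionMatcher expects: a callable over candidate spans returning a bool.

import re

def price_match(span):
    """Hypothetical: accept spans that look like a price, e.g. '$1,299.99'."""
    return re.fullmatch(r"\$?\d{1,4}(,\d{3})*(\.\d{1,5})?", span.get_span()) is not None

# Quick standalone check with a minimal stub for the span interface
class _Span:
    def __init__(self, text):
        self._text = text

    def get_span(self):
        return self._text

assert price_match(_Span("$1,299.99"))
assert not price_match(_Span("Lincoln"))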
Example #17
print(f'Number of documents: {session.query(Document).count()}')
#print(f'Number of sentences: {session.query(Sentence).count()}')
print("==============================")

# Getting all documents parsed by Fonduer
print("Getting documents and sentences...")
docs = session.query(Document).all()
#sents = session.query(Sentence).all()

from fonduer.candidates import CandidateExtractor, MentionExtractor, MentionNgrams
from fonduer.candidates.models import mention_subclass, candidate_subclass
from fonduer.candidates.matchers import RegexMatchSpan, Union

# Defining ngrams for candidates
extraction_name = "age"
age_ngrams = MentionNgrams(n_max=3)

# Define matchers
m = RegexMatchSpan(rgx=r'.*(I|He|She) (is|am) ([0-9]{2})')
p = RegexMatchSpan(rgx=r'.*(age|is|@|was) ([0-9]{2})')
q = RegexMatchSpan(rgx=r'.*(age:) ([0-9]{2})')
r = RegexMatchSpan(
    rgx=r'^([0-9]{2}) (yrs|years|year|yr|old|year-old|yr-old|Years|Year|Yr)'
)
s = RegexMatchSpan(rgx=r'(^|\W)age\W{0,4}[1-9]\d(\W|$)')

# Union matchers and create candidate extractor
age_matchers = Union(m, p, r, q, s)

# Getting candidates
AgeMention = mention_subclass("AgeMention")
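The snippet is truncated here; a plausible continuation, mirroring Example #16 (with session, docs, and parallelism as defined earlier in the script), might be:

mention_extractor = MentionExtractor(
    session, [AgeMention], [age_ngrams], [age_matchers]
)
mention_extractor.apply(docs, parallelism=parallelism)
Age = candidate_subclass("Age", [AgeMention])
candidate_extractor = CandidateExtractor(session, [Age])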
Example #18
def main(
    conn_string,
    stg_temp_min=False,
    stg_temp_max=False,
    polarity=False,
    ce_v_max=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    gpu=None,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    # Set up initial configuration
    if gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    rel_list = []
    if stg_temp_min:
        rel_list.append("stg_temp_min")

    if stg_temp_max:
        rel_list.append("stg_temp_max")

    if polarity:
        rel_list.append("polarity")

    if ce_v_max:
        rel_list.append("ce_v_max")

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info(f"Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(session,
                                                          dirname,
                                                          first_time=parse,
                                                          parallel=parallel,
                                                          max_docs=max_docs)
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    Part = mention_subclass("Part")
    part_matcher = get_matcher("part")
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)

    mentions.append(Part)
    ngrams.append(part_ngrams)
    matchers.append(part_matcher)

    if stg_temp_min:
        StgTempMin = mention_subclass("StgTempMin")
        stg_temp_min_matcher = get_matcher("stg_temp_min")
        stg_temp_min_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMin)
        ngrams.append(stg_temp_min_ngrams)
        matchers.append(stg_temp_min_matcher)

    if stg_temp_max:
        StgTempMax = mention_subclass("StgTempMax")
        stg_temp_max_matcher = get_matcher("stg_temp_max")
        stg_temp_max_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMax)
        ngrams.append(stg_temp_max_ngrams)
        matchers.append(stg_temp_max_matcher)

    if polarity:
        Polarity = mention_subclass("Polarity")
        polarity_matcher = get_matcher("polarity")
        polarity_ngrams = MentionNgrams(n_max=1)

        mentions.append(Polarity)
        ngrams.append(polarity_ngrams)
        matchers.append(polarity_matcher)

    if ce_v_max:
        CeVMax = mention_subclass("CeVMax")
        ce_v_max_matcher = get_matcher("ce_v_max")
        ce_v_max_ngrams = MentionNgramsVolt(n_max=1)

        mentions.append(CeVMax)
        ngrams.append(ce_v_max_ngrams)
        matchers.append(ce_v_max_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")
    logger.info(f"Total Part: {session.query(Part).count()}")
    if stg_temp_min:
        logger.info(f"Total StgTempMin: {session.query(StgTempMin).count()}")
    if stg_temp_max:
        logger.info(f"Total StgTempMax: {session.query(StgTempMax).count()}")
    if polarity:
        logger.info(f"Total Polarity: {session.query(Polarity).count()}")
    if ce_v_max:
        logger.info(f"Total CeVMax: {session.query(CeVMax).count()}")

    # Candidate Extraction
    cands = []
    throttlers = []
    if stg_temp_min:
        PartStgTempMin = candidate_subclass("PartStgTempMin",
                                            [Part, StgTempMin])
        stg_temp_min_throttler = stg_temp_filter

        cands.append(PartStgTempMin)
        throttlers.append(stg_temp_min_throttler)

    if stg_temp_max:
        PartStgTempMax = candidate_subclass("PartStgTempMax",
                                            [Part, StgTempMax])
        stg_temp_max_throttler = stg_temp_filter

        cands.append(PartStgTempMax)
        throttlers.append(stg_temp_max_throttler)

    if polarity:
        PartPolarity = candidate_subclass("PartPolarity", [Part, Polarity])
        polarity_throttler = polarity_filter

        cands.append(PartPolarity)
        throttlers.append(polarity_throttler)

    if ce_v_max:
        PartCeVMax = candidate_subclass("PartCeVMax", [Part, CeVMax])
        ce_v_max_throttler = ce_v_max_filter

        cands.append(PartCeVMax)
        throttlers.append(ce_v_max_throttler)

    candidate_extractor = CandidateExtractor(session,
                                             cands,
                                             throttlers=throttlers)

    if first_time:
        for i, docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(docs, split=i, parallelism=parallel)
            num_cands = session.query(Candidate).filter(
                Candidate.split == i).count()
            logger.info(f"Candidates in split={i}: {num_cands}")

    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)

    end = timer()
    logger.warning(
        f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"Total train candidate: {sum(len(_) for _ in train_cands)}")
    logger.info(f"Total dev candidate: {sum(len(_) for _ in dev_cands)}")
    logger.info(f"Total test candidate: {sum(len(_) for _ in test_cands)}")

    pickle_file = os.path.join(dirname, "data/parts_by_doc_new.pkl")
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    # Check total recall
    for i, name in enumerate(rel_list):
        logger.info(name)
        result = entity_level_scores(
            candidates_to_entities(dev_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=dev_docs,
        )
        logger.info(f"{name} Total Dev Recall: {result.rec:.3f}")
        result = entity_level_scores(
            candidates_to_entities(test_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=test_docs,
        )
        logger.info(f"{name} Total Test Recall: {result.rec:.3f}")

    # Featurization
    start = timer()
    cands = []
    if stg_temp_min:
        cands.append(PartStgTempMin)

    if stg_temp_max:
        cands.append(PartStgTempMax)

    if polarity:
        cands.append(PartPolarity)

    if ce_v_max:
        cands.append(PartCeVMax)

    featurizer = Featurizer(session, cands)
    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True, parallelism=parallel)
        featurizer.apply(split=1, parallelism=parallel)
        featurizer.apply(split=2, parallelism=parallel)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        pickle.dump(F_train, open(os.path.join(dirname, "F_train.pkl"), "wb"))
        pickle.dump(F_dev, open(os.path.join(dirname, "F_dev.pkl"), "wb"))
        pickle.dump(F_test, open(os.path.join(dirname, "F_test.pkl"), "wb"))
    else:
        F_train = pickle.load(open(os.path.join(dirname, "F_train.pkl"), "rb"))
        F_dev = pickle.load(open(os.path.join(dirname, "F_dev.pkl"), "rb"))
        F_test = pickle.load(open(os.path.join(dirname, "F_test.pkl"), "rb"))
    logger.info("Done.")

    for i, cand in enumerate(cands):
        logger.info(f"{cand} Train shape: {F_train[i].shape}")
        logger.info(f"{cand} Test shape: {F_test[i].shape}")
        logger.info(f"{cand} Dev shape: {F_dev[i].shape}")

    logger.info("Labeling training data...")

    # Labeling
    start = timer()
    lfs = []
    if stg_temp_min:
        lfs.append(stg_temp_min_lfs)

    if stg_temp_max:
        lfs.append(stg_temp_max_lfs)

    if polarity:
        lfs.append(polarity_lfs)

    if ce_v_max:
        lfs.append(ce_v_max_lfs)

    labeler = Labeler(session, cands)

    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
        logger.info("Done...")

        # Uncomment if debugging LFs
        #  load_transistor_labels(session, cands, ["ce_v_max"])
        #  labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        #  labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    elif re_label:
        logger.info("Updating LFs...")
        labeler.update(split=0, lfs=lfs, parallelism=parallel)
        logger.info("Done...")

        # Uncomment if debugging LFs
        #  labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        #  labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    logger.info("Getting label matrices...")

    L_train = labeler.get_label_matrices(train_cands)

    # Uncomment if debugging LFs
    #  L_dev = labeler.get_label_matrices(dev_cands)
    #  L_dev_gold = labeler.get_gold_labels(dev_cands, annotator="gold")
    #
    #  L_test = labeler.get_label_matrices(test_cands)
    #  L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")

    logger.info("Done.")

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    start = timer()
    if stg_temp_min:
        relation = "stg_temp_min"
        idx = rel_list.index(relation)
        marginals_stg_temp_min = generative_model(L_train[idx])
        disc_model_stg_temp_min = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals_stg_temp_min,
            n_epochs=100,
            gpu=gpu,
        )
        best_result, best_b = scoring(
            relation,
            disc_model_stg_temp_min,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

    if stg_temp_max:
        relation = "stg_temp_max"
        idx = rel_list.index(relation)
        marginals_stg_temp_max = generative_model(L_train[idx])
        disc_model_stg_temp_max = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals_stg_temp_max,
            n_epochs=100,
            gpu=gpu,
        )
        best_result, best_b = scoring(
            relation,
            disc_model_stg_temp_max,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

    if polarity:
        relation = "polarity"
        idx = rel_list.index(relation)
        marginals_polarity = generative_model(L_train[idx])
        disc_model_polarity = discriminative_model(train_cands[idx],
                                                   F_train[idx],
                                                   marginals_polarity,
                                                   n_epochs=100,
                                                   gpu=gpu)
        best_result, best_b = scoring(
            relation,
            disc_model_polarity,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

    if ce_v_max:
        relation = "ce_v_max"
        idx = rel_list.index(relation)

        # Can be uncommented for use in debugging labeling functions
        #  logger.info("Updating labeling function summary...")
        #  keys = labeler.get_keys()
        #  logger.info("Summary for train set labeling functions:")
        #  df = analysis.lf_summary(L_train[idx], lf_names=keys)
        #  logger.info(f"\n{df.to_string()}")
        #
        #  logger.info("Summary for dev set labeling functions:")
        #  df = analysis.lf_summary(
        #      L_dev[idx],
        #      lf_names=keys,
        #      Y=L_dev_gold[idx].todense().reshape(-1).tolist()[0],
        #  )
        #  logger.info(f"\n{df.to_string()}")
        #
        #  logger.info("Summary for test set labeling functions:")
        #  df = analysis.lf_summary(
        #      L_test[idx],
        #      lf_names=keys,
        #      Y=L_test_gold[idx].todense().reshape(-1).tolist()[0],
        #  )
        #  logger.info(f"\n{df.to_string()}")

        marginals_ce_v_max = generative_model(L_train[idx])
        disc_model_ce_v_max = discriminative_model(train_cands[idx],
                                                   F_train[idx],
                                                   marginals_ce_v_max,
                                                   n_epochs=100,
                                                   gpu=gpu)

        # Can be uncommented to view score on development set
        #  best_result, best_b = scoring(
        #      relation,
        #      disc_model_ce_v_max,
        #      dev_cands[idx],
        #      dev_docs,
        #      F_dev[idx],
        #      parts_by_doc,
        #      num=100,
        #  )

        best_result, best_b = scoring(
            relation,
            disc_model_ce_v_max,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

    end = timer()
    logger.warning(f"Classification Time (min): {((end - start) / 60.0):.1f}")

    # Dump CSV files for CE_V_MAX for digi-key analysis
    if ce_v_max:
        relation = "ce_v_max"
        idx = rel_list.index(relation)
        Y_prob = disc_model_ce_v_max.marginals((test_cands[idx], F_test[idx]))
        dump_candidates(test_cands[idx], Y_prob, "ce_v_max_test_probs.csv")
        Y_prob = disc_model_ce_v_max.marginals((dev_cands[idx], F_dev[idx]))
        dump_candidates(dev_cands[idx], Y_prob, "ce_v_max_dev_probs.csv")

    # Dump CSV files for POLARITY for digi-key analysis
    if polarity:
        relation = "polarity"
        idx = rel_list.index(relation)
        Y_prob = disc_model_polarity.marginals((test_cands[idx], F_test[idx]))
        dump_candidates(test_cands[idx], Y_prob, "polarity_test_probs.csv")
        Y_prob = disc_model_polarity.marginals((dev_cands[idx], F_dev[idx]))
        dump_candidates(dev_cands[idx], Y_prob, "polarity_dev_probs.csv")
Example #19
]

stations_mapping_dict = {
    k: station_list
    for station_list in stations for k in station_list
}
stations_list = [s for station_list in stations for s in station_list]
station_rgx = '|'.join(stations_list)

# 1.) Mention classes
Station = mention_subclass("Station")
Price = mention_subclass("Price")

# 2.) Mention spaces
station_ngrams = MentionNgrams(
    n_max=4, split_tokens=[" ", "_", r"\.", "%"]
)  # alternatively: StationMentionSpace(n_max=4)
price_ngrams = MentionNgrams(n_max=1)

# 3.) Matcher functions
station_matcher = RegexMatchFull(
    rgx=station_rgx,
    ignore_case=True,
    # search=True,
    # full_match=False,
    # longest_match_only=False,
)  # alternatively: DictionaryMatch(d=stations_list)
price_matcher = RegexMatchSpan(rgx=r"\d{1,4}(\.\d{1,5})",
                               longest_match_only=True)
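
# Quick sanity check of the price pattern (plain `re` is used here only for
# illustration; RegexMatchSpan applies the pattern to candidate spans):
import re

_price_rgx = re.compile(r"\d{1,4}(\.\d{1,5})")
assert _price_rgx.fullmatch("19.99")   # a price with a decimal part matches
assert _price_rgx.fullmatch("120.5")
assert not _price_rgx.fullmatch("42")  # integer-only spans are rejected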

# 4.) Candidate classes
Exemplo n.º 20
def main(
    conn_string,
    gain=False,
    current=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    parallel=8,
    log_dir="logs",
    verbose=False,
):
    # Setup initial configuration
    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    rel_list = []
    if gain:
        rel_list.append("gain")

    if current:
        rel_list.append("current")

    logger.info("=" * 30)
    logger.info(f"Running with parallel: {parallel}, max_docs: {max_docs}")

    session = Meta.init(conn_string).Session()

    # Parsing
    start = timer()
    logger.info(f"Starting parsing...")
    docs, train_docs, dev_docs, test_docs = parse_dataset(session,
                                                          dirname,
                                                          first_time=parse,
                                                          parallel=parallel,
                                                          max_docs=max_docs)
    logger.debug(f"Done")
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of Documents: {len(docs)}")
    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    if gain:
        Gain = mention_subclass("Gain")
        gain_matcher = get_gain_matcher()
        gain_ngrams = MentionNgrams(n_max=2)
        mentions.append(Gain)
        ngrams.append(gain_ngrams)
        matchers.append(gain_matcher)

    if current:
        Current = mention_subclass("SupplyCurrent")
        current_matcher = get_supply_current_matcher()
        current_ngrams = MentionNgramsCurrent(n_max=3)
        mentions.append(Current)
        ngrams.append(current_ngrams)
        matchers.append(current_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")

    if gain:
        logger.info(f"Total Gain: {session.query(Gain).count()}")

    if current:
        logger.info(f"Total Current: {session.query(Current).count()}")

    cand_classes = []
    if gain:
        GainCand = candidate_subclass("GainCand", [Gain])
        cand_classes.append(GainCand)
    if current:
        CurrentCand = candidate_subclass("CurrentCand", [Current])
        cand_classes.append(CurrentCand)

    candidate_extractor = CandidateExtractor(session, cand_classes)

    if first_time:
        for i, docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(docs, split=i, parallelism=parallel)

    # These must be sorted for deterministic behavior.
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    # N.B. these totals assume both relations are enabled.
    logger.info(
        f"Total train candidates: {len(train_cands[0]) + len(train_cands[1])}")
    logger.info(
        f"Total dev candidates: {len(dev_cands[0]) + len(dev_cands[1])}")
    logger.info(
        f"Total test candidates: {len(test_cands[0]) + len(test_cands[1])}")

    logger.info("Done w/ candidate extraction.")
    end = timer()
    logger.warning(f"CE Time (min): {((end - start) / 60.0):.1f}")

    # First, check total recall
    #  result = entity_level_scores(
    #      candidates_to_entities(dev_cands[0], is_gain=True),
    #      corpus=dev_docs,
    #      is_gain=True,
    #  )
    #  logger.info(f"Gain Total Dev Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #  result = entity_level_scores(
    #      candidates_to_entities(test_cands[0], is_gain=True),
    #      corpus=test_docs,
    #      is_gain=True,
    #  )
    #  logger.info(f"Gain Total Test Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #
    #  result = entity_level_scores(
    #      candidates_to_entities(dev_cands[1], is_gain=False),
    #      corpus=dev_docs,
    #      is_gain=False,
    #  )
    #  logger.info(f"Current Total Dev Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #  result = entity_level_scores(
    #      candidates_to_entities(test_cands[1], is_gain=False),
    #      corpus=test_docs,
    #      is_gain=False,
    #  )
    #  logger.info(f"Current Test Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")

    start = timer()

    # Using parallelism = 1 for deterministic behavior.
    featurizer = Featurizer(session, cand_classes, parallelism=1)

    if first_time:
        logger.info("Starting featurizer...")
        # Set feature space based on dev set, which we use for training rather
        # than the large train set.
        featurizer.apply(split=1, train=True)
        featurizer.apply(split=0)
        featurizer.apply(split=2)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    # Serialize feature matrices on first run
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        F_train_dict = {}
        F_dev_dict = {}
        F_test_dict = {}
        for idx, relation in enumerate(rel_list):
            F_train_dict[relation] = F_train[idx]
            F_dev_dict[relation] = F_dev[idx]
            F_test_dict[relation] = F_test[idx]

        pickle.dump(F_train_dict,
                    open(os.path.join(dirname, "F_train_dict.pkl"), "wb"))
        pickle.dump(F_dev_dict,
                    open(os.path.join(dirname, "F_dev_dict.pkl"), "wb"))
        pickle.dump(F_test_dict,
                    open(os.path.join(dirname, "F_test_dict.pkl"), "wb"))
    else:
        F_train_dict = pickle.load(
            open(os.path.join(dirname, "F_train_dict.pkl"), "rb"))
        F_dev_dict = pickle.load(
            open(os.path.join(dirname, "F_dev_dict.pkl"), "rb"))
        F_test_dict = pickle.load(
            open(os.path.join(dirname, "F_test_dict.pkl"), "rb"))

        F_train = []
        F_dev = []
        F_test = []
        for relation in rel_list:
            F_train.append(F_train_dict[relation])
            F_dev.append(F_dev_dict[relation])
            F_test.append(F_test_dict[relation])

    logger.info("Done.")

    start = timer()
    logger.info("Labeling training data...")
    #  labeler = Labeler(session, cand_classes)
    #  lfs = []
    #  if gain:
    #      lfs.append(gain_lfs)
    #
    #  if current:
    #      lfs.append(current_lfs)
    #
    #  if first_time:
    #      logger.info("Applying LFs...")
    #      labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
    #  elif re_label:
    #      logger.info("Re-applying LFs...")
    #      labeler.update(split=0, lfs=lfs, parallelism=parallel)
    #
    #  logger.info("Done...")

    #  logger.info("Getting label matrices...")
    #  L_train = labeler.get_label_matrices(train_cands)
    #  logger.info("Done...")

    if first_time:
        marginals_dict = {}
        for idx, relation in enumerate(rel_list):
            # Manually create marginals from human annotations
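            # Each row is [P(negative), P(positive)].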
            marginal = []
            dev_gold_entities = get_gold_set(is_gain=(relation == "gain"))
            for c in dev_cands[idx]:
                flag = False
                for entity in cand_to_entity(c, is_gain=(relation == "gain")):
                    if entity in dev_gold_entities:
                        flag = True

                if flag:
                    marginal.append([0.0, 1.0])
                else:
                    marginal.append([1.0, 0.0])

            marginals_dict[relation] = np.array(marginal)

        pickle.dump(marginals_dict,
                    open(os.path.join(dirname, "marginals_dict.pkl"), "wb"))
    else:
        marginals_dict = pickle.load(
            open(os.path.join(dirname, "marginals_dict.pkl"), "rb"))

    marginals = []
    for relation in rel_list:
        marginals.append(marginals_dict[relation])

    end = timer()
    logger.warning(
        f"Weak Supervision Time (min): {((end - start) / 60.0):.1f}")

    start = timer()

    word_counter = collect_word_counter(train_cands)

    # Training config
    config = {
        "meta_config": {
            "verbose": True,
            "seed": 30
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 500,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.005
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    "model/all/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }
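    # Key settings above: 500 training epochs (lr=1e-3, l2=5e-3), round-robin
    # scheduling across tasks, and checkpointing disabled.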

    emmental.init(log_dir=Meta.log_path, config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]
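
    # For arity 2 this yields ["~~[[0", "0]]~~", "~~[[1", "1]]~~"]: one
    # opening/closing marker pair per candidate slot, so the embedding module
    # reserves dedicated vectors for the span delimiters.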

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)
    train_idxs = []
    train_dataloader = []
    for idx, relation in enumerate(rel_list):
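        # Keep only candidates with a non-uniform marginal; rows where
        # max == min carry no label signal and are dropped.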
        diffs = marginals[idx].max(axis=1) - marginals[idx].min(axis=1)
        train_idxs.append(np.where(diffs > 1e-6)[0])

        # Only the dev set (with human annotations) is used as training data.
        train_dataloader.append(
            EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(
                    relation,
                    dev_cands[idx],
                    F_dev[idx],
                    emb_layer.word2id,
                    marginals[idx],
                    train_idxs[idx],
                ),
                split="train",
                batch_size=256,
                shuffle=True,
            ))

    num_feature_keys = len(featurizer.get_keys())

    model = EmmentalModel(name="opamp_tasks")

    # Create one task per relation: names, arities, feature-space size, and
    # number of classes.
    tasks = create_task(
        rel_list,
        [2] * len(rel_list),
        num_feature_keys,
        [2] * len(rel_list),
        emb_layer,
        model="LogisticRegression",
    )

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()

    # Given a list of dataloaders, trains on multiple tasks jointly.
    emmental_learner.learn(model, train_dataloader)

    # Evaluate each relation with its own test dataloader.
    for idx, relation in enumerate(rel_list):
        test_dataloader = EmmentalDataLoader(
            task_to_label_dict={relation: "labels"},
            dataset=FonduerDataset(relation, test_cands[idx], F_test[idx],
                                   emb_layer.word2id, 2),
            split="test",
            batch_size=256,
            shuffle=False,
        )

        test_preds = model.predict(test_dataloader, return_preds=True)

        best_result, best_b = scoring(
            test_preds,
            test_cands[idx],
            test_docs,
            is_gain=(relation == "gain"),
            num=100,
        )

        # Dump CSV files for analysis
        if relation == "gain":
            train_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, train_cands[idx],
                                       F_train[idx], emb_layer.word2id, 2),
                split="train",
                batch_size=256,
                shuffle=False,
            )

            train_preds = model.predict(train_dataloader, return_preds=True)
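            # TRUE is the positive-class column index (assumed to be 1 in
            # this repo) into the per-class probability matrix.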
            Y_prob = np.array(train_preds["probs"][relation])[:, TRUE]
            output_csv(train_cands[idx], Y_prob, is_gain=True)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            output_csv(test_cands[idx], Y_prob, is_gain=True, append=True)
            dump_candidates(test_cands[idx],
                            Y_prob,
                            "gain_test_probs.csv",
                            is_gain=True)

            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=256,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            output_csv(dev_cands[idx], Y_prob, is_gain=True, append=True)
            dump_candidates(dev_cands[idx],
                            Y_prob,
                            "gain_dev_probs.csv",
                            is_gain=True)

        if relation == "current":
            train_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, train_cands[idx],
                                       F_train[idx], emb_layer.word2id, 2),
                split="train",
                batch_size=256,
                shuffle=False,
            )

            train_preds = model.predict(train_dataloader, return_preds=True)
            Y_prob = np.array(train_preds["probs"][relation])[:, TRUE]
            output_csv(train_cands[idx], Y_prob, is_gain=False)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            output_csv(test_cands[idx], Y_prob, is_gain=False, append=True)
            dump_candidates(test_cands[idx],
                            Y_prob,
                            "current_test_probs.csv",
                            is_gain=False)

            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=256,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            output_csv(dev_cands[idx], Y_prob, is_gain=False, append=True)
            dump_candidates(dev_cands[idx],
                            Y_prob,
                            "current_dev_probs.csv",
                            is_gain=False)

    end = timer()
    logger.warning(
        f"Classification and data-dump time (min): {((end - start) / 60.0):.1f}"
    )
Exemplo n.º 21
def main(
    conn_string,
    gain=False,
    current=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    gpu=None,
    parallel=8,
    log_dir="logs",
    verbose=False,
):
    # Setup initial configuration
    if gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    rel_list = []
    if gain:
        rel_list.append("gain")

    if current:
        rel_list.append("current")

    logger.info("=" * 30)
    logger.info(f"Running with parallel: {parallel}, max_docs: {max_docs}")

    session = Meta.init(conn_string).Session()

    # Parsing
    start = timer()
    logger.info(f"Starting parsing...")
    docs, train_docs, dev_docs, test_docs = parse_dataset(session,
                                                          dirname,
                                                          first_time=parse,
                                                          parallel=parallel,
                                                          max_docs=max_docs)
    logger.debug(f"Done")
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of Documents: {len(docs)}")
    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    if gain:
        Gain = mention_subclass("Gain")
        gain_matcher = get_gain_matcher()
        gain_ngrams = MentionNgrams(n_max=2)
        mentions.append(Gain)
        ngrams.append(gain_ngrams)
        matchers.append(gain_matcher)

    if current:
        Current = mention_subclass("SupplyCurrent")
        current_matcher = get_supply_current_matcher()
        current_ngrams = MentionNgramsCurrent(n_max=3)
        mentions.append(Current)
        ngrams.append(current_ngrams)
        matchers.append(current_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")

    if gain:
        logger.info(f"Total Gain: {session.query(Gain).count()}")

    if current:
        logger.info(f"Total Current: {session.query(Current).count()}")

    cand_classes = []
    if gain:
        GainCand = candidate_subclass("GainCand", [Gain])
        cand_classes.append(GainCand)
    if current:
        CurrentCand = candidate_subclass("CurrentCand", [Current])
        cand_classes.append(CurrentCand)

    candidate_extractor = CandidateExtractor(session, cand_classes)

    if first_time:
        for i, docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(docs, split=i, parallelism=parallel)

    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)
    # N.B. these totals assume both relations are enabled.
    logger.info(
        f"Total train candidates: {len(train_cands[0]) + len(train_cands[1])}")
    logger.info(
        f"Total dev candidates: {len(dev_cands[0]) + len(dev_cands[1])}")
    logger.info(
        f"Total test candidates: {len(test_cands[0]) + len(test_cands[1])}")

    logger.info("Done w/ candidate extraction.")
    end = timer()
    logger.warning(f"CE Time (min): {((end - start) / 60.0):.1f}")

    # First, check total recall
    #  result = entity_level_scores(dev_cands[0], corpus=dev_docs)
    #  logger.info(f"Gain Total Dev Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #  result = entity_level_scores(test_cands[0], corpus=test_docs)
    #  logger.info(f"Gain Total Test Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #
    #  result = entity_level_scores(dev_cands[1], corpus=dev_docs, is_gain=False)
    #  logger.info(f"Current Total Dev Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #  result = entity_level_scores(test_cands[1], corpus=test_docs, is_gain=False)
    #  logger.info(f"Current Test Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")

    start = timer()
    featurizer = Featurizer(session, cand_classes)

    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True, parallelism=parallel)
        featurizer.apply(split=1, parallelism=parallel)
        featurizer.apply(split=2, parallelism=parallel)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    # Serialize feature matrices on first run
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        pickle.dump(F_train, open(os.path.join(dirname, "F_train.pkl"), "wb"))
        pickle.dump(F_dev, open(os.path.join(dirname, "F_dev.pkl"), "wb"))
        pickle.dump(F_test, open(os.path.join(dirname, "F_test.pkl"), "wb"))
    else:
        F_train = pickle.load(open(os.path.join(dirname, "F_train.pkl"), "rb"))
        F_dev = pickle.load(open(os.path.join(dirname, "F_dev.pkl"), "rb"))
        F_test = pickle.load(open(os.path.join(dirname, "F_test.pkl"), "rb"))
    logger.info("Done.")

    start = timer()
    logger.info("Labeling training data...")
    labeler = Labeler(session, cand_classes)
    lfs = []
    if gain:
        lfs.append(gain_lfs)

    if current:
        lfs.append(current_lfs)

    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
    elif re_label:
        logger.info("Re-applying LFs...")
        labeler.update(split=0, lfs=lfs, parallelism=parallel)

    logger.info("Done...")

    logger.info("Getting label matrices...")
    L_train = labeler.get_label_matrices(train_cands)
    logger.info("Done...")

    end = timer()
    logger.warning(
        f"Weak Supervision Time (min): {((end - start) / 60.0):.1f}")

    if gain:
        relation = "gain"
        idx = rel_list.index(relation)

        logger.info("Score Gain.")
        dev_gold_entities = get_gold_set(is_gain=True)
        L_dev_gt = []
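        # TRUE and FALSE are integer label constants defined in this repo
        # (assumed to be 1 and 0).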
        for c in dev_cands[idx]:
            flag = FALSE
            for entity in cand_to_entity(c, is_gain=True):
                if entity in dev_gold_entities:
                    flag = TRUE
            L_dev_gt.append(flag)

        marginals = generative_model(L_train[idx])
        disc_models = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals,
            X_dev=(dev_cands[idx], F_dev[idx]),
            Y_dev=L_dev_gt,
            n_epochs=500,
            gpu=gpu,
        )
        best_result, best_b = scoring(disc_models,
                                      test_cands[idx],
                                      test_docs,
                                      F_test[idx],
                                      num=50)

        print_scores(relation, best_result, best_b)

        logger.info("Output CSV files for Opo and Digi-Key analysis.")
        Y_prob = disc_models.marginals((train_cands[idx], F_train[idx]))
        output_csv(train_cands[idx], Y_prob, is_gain=True)

        Y_prob = disc_models.marginals((test_cands[idx], F_test[idx]))
        output_csv(test_cands[idx], Y_prob, is_gain=True, append=True)
        dump_candidates(test_cands[idx],
                        Y_prob,
                        "gain_test_probs.csv",
                        is_gain=True)

        Y_prob = disc_models.marginals((dev_cands[idx], F_dev[idx]))
        output_csv(dev_cands[idx], Y_prob, is_gain=True, append=True)
        dump_candidates(dev_cands[idx],
                        Y_prob,
                        "gain_dev_probs.csv",
                        is_gain=True)

    if current:
        relation = "current"
        idx = rel_list.index(relation)

        logger.info("Score Current.")
        dev_gold_entities = get_gold_set(is_gain=False)
        L_dev_gt = []
        for c in dev_cands[idx]:
            flag = FALSE
            for entity in cand_to_entity(c, is_gain=False):
                if entity in dev_gold_entities:
                    flag = TRUE
            L_dev_gt.append(flag)

        marginals = generative_model(L_train[idx])

        disc_models = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals,
            X_dev=(dev_cands[idx], F_dev[idx]),
            Y_dev=L_dev_gt,
            n_epochs=100,
            gpu=gpu,
        )
        best_result, best_b = scoring(disc_models,
                                      test_cands[idx],
                                      test_docs,
                                      F_test[idx],
                                      is_gain=False,
                                      num=50)

        print_scores(relation, best_result, best_b)

        logger.info("Output CSV files for Opo and Digi-Key analysis.")
        # Dump CSV files for Digi-Key analysis and Opo comparison
        Y_prob = disc_models.marginals((train_cands[idx], F_train[idx]))
        output_csv(train_cands[idx], Y_prob, is_gain=False)

        Y_prob = disc_models.marginals((test_cands[idx], F_test[idx]))
        output_csv(test_cands[idx], Y_prob, is_gain=False, append=True)
        dump_candidates(test_cands[idx],
                        Y_prob,
                        "current_test_probs.csv",
                        is_gain=False)

        Y_prob = disc_models.marginals((dev_cands[idx], F_dev[idx]))
        output_csv(dev_cands[idx], Y_prob, is_gain=False, append=True)
        dump_candidates(dev_cands[idx],
                        Y_prob,
                        "current_dev_probs.csv",
                        is_gain=False)

    end = timer()
    logger.warning(
        f"Classification and data-dump time (min): {((end - start) / 60.0):.1f}"
    )
Exemplo n.º 22

birth_place_in_labeled_row_matcher = LambdaFunctionMatcher(
    func=is_in_birthplace_table_row)
birth_place_in_labeled_row_matcher.longest_match_only = False
birth_place_no_commas_matcher = LambdaFunctionMatcher(
    func=no_commas_in_birth_place)
birth_place_left_aligned_matcher = LambdaFunctionMatcher(
    func=birthplace_left_aligned_to_punctuation)

place_of_birth_matcher = Intersect(
    birth_place_in_labeled_row_matcher,
    birth_place_no_commas_matcher,
    birth_place_left_aligned_matcher,
)
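
# The three predicate functions above are defined earlier in this notebook; a
# minimal sketch of one of them (hypothetical, for illustration only):
def no_commas_in_birth_place_sketch(mention):
    """Reject spans that contain a comma, e.g. 'Hodgenville, Kentucky'."""
    return "," not in mention.get_span()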

from fonduer.candidates import MentionNgrams

presname_ngrams = MentionNgrams(n_max=4, n_min=2)
placeofbirth_ngrams = MentionNgrams(n_max=3)
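
# President names must span 2-4 tokens (n_min=2 excludes bare surnames), while
# birthplaces may be up to 3 tokens long.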

from fonduer.candidates.models import candidate_subclass

PresidentnamePlaceofbirth = candidate_subclass("PresidentnamePlaceofbirth",
                                               [Presidentname, Placeofbirth])

mention_classes = [Presidentname, Placeofbirth]
mention_spaces = [presname_ngrams, placeofbirth_ngrams]
matchers = [president_name_matcher, place_of_birth_matcher]
candidate_classes = [PresidentnamePlaceofbirth]
Exemplo n.º 23
def test_unary_relation_feature_extraction():
    """Test feature extraction on unary candidates extracted from documents."""
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    # Parsing
    doc = parse_doc(docs_path, "112823", pdf_path)
    assert len(doc.sentences) == 799

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)

    Part = mention_subclass("Part")

    mention_extractor_udf = MentionExtractorUDF([Part], [part_ngrams],
                                                [part_matcher])
    doc = mention_extractor_udf.apply(doc)

    assert doc.name == "112823"
    assert len(doc.parts) == 62
    part = doc.parts[0]
    logger.info(f"Part: {part.context}")

    # Candidate Extraction
    PartRel = candidate_subclass("PartRel", [Part])

    candidate_extractor_udf = CandidateExtractorUDF([PartRel], None, False,
                                                    False, True)
    doc = candidate_extractor_udf.apply(doc, split=0)

    # Featurization based on the default feature library
    featurizer_udf = FeaturizerUDF([PartRel], FeatureExtractor())

    # Count keys produced by the default feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_default_feats = len(key_set)
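
    # featurizer_udf.apply returns one list of feature dicts per candidate
    # class; flattening them and collecting their "keys" gives the size of
    # the distinct feature space.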

    # Featurization with only textual feature
    feature_extractors = FeatureExtractor(features=["textual"])
    featurizer_udf = FeaturizerUDF([PartRel],
                                   feature_extractors=feature_extractors)

    # Count keys produced by the textual feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_textual_features = len(key_set)

    # Featurization with only tabular feature
    feature_extractors = FeatureExtractor(features=["tabular"])
    featurizer_udf = FeaturizerUDF([PartRel],
                                   feature_extractors=feature_extractors)

    # Count keys produced by the tabular feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_tabular_features = len(key_set)

    # Featurization with only structural feature
    feature_extractors = FeatureExtractor(features=["structural"])
    featurizer_udf = FeaturizerUDF([PartRel],
                                   feature_extractors=feature_extractors)

    # Count keys produced by the structural feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_structural_features = len(key_set)

    # Featurization with only visual feature
    feature_extractors = FeatureExtractor(features=["visual"])
    featurizer_udf = FeaturizerUDF([PartRel],
                                   feature_extractors=feature_extractors)

    # Count keys produced by the visual feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_visual_features = len(key_set)

    assert (n_default_feats == n_textual_features + n_tabular_features +
            n_structural_features + n_visual_features)
Exemplo n.º 24
def test_multinary_relation_feature_extraction():
    """Test feature extraction on multinary candidates extracted from documents."""
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    # Parsing
    doc = parse_doc(docs_path, "112823", pdf_path)
    assert len(doc.sentences) == 799

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)
    temp_ngrams = MentionNgrams(n_max=1)
    volt_ngrams = MentionNgrams(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor_udf = MentionExtractorUDF(
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )
    doc = mention_extractor_udf.apply(doc)

    assert len(doc.parts) == 62
    assert len(doc.temps) == 16
    assert len(doc.volts) == 33
    part = doc.parts[0]
    temp = doc.temps[0]
    volt = doc.volts[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")
    logger.info(f"Volt: {volt.context}")

    # Candidate Extraction
    PartTempVolt = candidate_subclass("PartTempVolt", [Part, Temp, Volt])

    candidate_extractor_udf = CandidateExtractorUDF([PartTempVolt], None,
                                                    False, False, True)

    doc = candidate_extractor_udf.apply(doc, split=0)

    # Manually set ids, as they are not assigned automatically because no
    # database is used.
    for i, cand in enumerate(doc.part_temp_volts):
        cand.id = i

    n_cands = len(doc.part_temp_volts)

    # Featurization based on the default feature library
    featurizer_udf = FeaturizerUDF([PartTempVolt], FeatureExtractor())

    # Count keys produced by the default feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_default_feats = len(key_set)

    # Example feature extractor
    def feat_ext(candidates):
        candidates = candidates if isinstance(candidates,
                                              list) else [candidates]
        for candidate in candidates:
            yield candidate.id, f"cand_id_{candidate.id}", 1
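
    # feat_ext emits exactly one unique key per candidate (derived from its
    # id), so the customized feature space should exceed the default one by
    # n_cands keys; the final assertion below checks this.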

    # Featurization with one extra feature extractor
    feature_extractors = FeatureExtractor(customize_feature_funcs=[feat_ext])
    featurizer_udf = FeaturizerUDF([PartTempVolt],
                                   feature_extractors=feature_extractors)

    # Count keys from the default library plus the extra feature extractor
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_default_w_customized_features = len(key_set)

    # Example spurious feature extractor
    def bad_feat_ext(candidates):
        raise RuntimeError()

    # Featurization with a spurious feature extractor
    feature_extractors = FeatureExtractor(
        customize_feature_funcs=[bad_feat_ext])
    featurizer_udf = FeaturizerUDF([PartTempVolt],
                                   feature_extractors=feature_extractors)

    # The spurious feature extractor should raise a RuntimeError
    logger.info("Featurizing with a spurious feature extractor...")
    with pytest.raises(RuntimeError):
        features = featurizer_udf.apply(doc)

    # Featurization with only textual feature
    feature_extractors = FeatureExtractor(features=["textual"])
    featurizer_udf = FeaturizerUDF([PartTempVolt],
                                   feature_extractors=feature_extractors)

    # Count keys produced by the textual feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_textual_features = len(key_set)

    # Featurization with only tabular feature
    feature_extractors = FeatureExtractor(features=["tabular"])
    featurizer_udf = FeaturizerUDF([PartTempVolt],
                                   feature_extractors=feature_extractors)

    # Count keys produced by the tabular feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_tabular_features = len(key_set)

    # Featurization with only structural feature
    feature_extractors = FeatureExtractor(features=["structural"])
    featurizer_udf = FeaturizerUDF([PartTempVolt],
                                   feature_extractors=feature_extractors)

    # Count keys produced by the structural feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_structural_features = len(key_set)

    # Featurization with only visual feature
    feature_extractors = FeatureExtractor(features=["visual"])
    featurizer_udf = FeaturizerUDF([PartTempVolt],
                                   feature_extractors=feature_extractors)

    # Count keys produced by the visual feature library
    features_list = featurizer_udf.apply(doc)
    features = itertools.chain.from_iterable(features_list)
    key_set = set([key for feature in features for key in feature["keys"]])
    n_visual_features = len(key_set)

    assert (n_default_feats == n_textual_features + n_tabular_features +
            n_structural_features + n_visual_features)

    assert n_default_w_customized_features == n_default_feats + n_cands
Exemplo n.º 25
def test_unary_relation_feature_extraction():
    """Test feature extraction on unary candidates extracted from documents."""
    PARALLEL = 1

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session,
                           structural=True,
                           lingual=True,
                           visual=True,
                           pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgrams(n_max=1)

    Part = mention_subclass("Part")

    mention_extractor = MentionExtractor(session, [Part], [part_ngrams],
                                         [part_matcher])
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert docs[0].name == "112823"
    assert session.query(Part).count() == 58
    part = session.query(Part).order_by(Part.id).first()
    logger.info(f"Part: {part.context}")

    # Candidate Extraction
    PartRel = candidate_subclass("PartRel", [Part])

    candidate_extractor = CandidateExtractor(session, [PartRel])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    # Featurization based on the default feature library
    featurizer = Featurizer(session, [PartRel])

    # Count FeatureKeys produced by the default feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_default_feats = session.query(FeatureKey).count()
    featurizer.clear(train=True)
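    # clear(train=True) drops the persisted FeatureKeys so each feature
    # subset below is counted from a clean slate.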

    # Featurization with only textual feature
    feature_extractors = FeatureExtractor(features=["textual"])
    featurizer = Featurizer(session, [PartRel],
                            feature_extractors=feature_extractors)

    # Count FeatureKeys produced by the textual feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_textual_features = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    # Featurization with only tabular feature
    feature_extractors = FeatureExtractor(features=["tabular"])
    featurizer = Featurizer(session, [PartRel],
                            feature_extractors=feature_extractors)

    # Count FeatureKeys produced by the tabular feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_tabular_features = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    # Featurization with only structural feature
    feature_extractors = FeatureExtractor(features=["structural"])
    featurizer = Featurizer(session, [PartRel],
                            feature_extractors=feature_extractors)

    # Count FeatureKeys produced by the structural feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_structural_features = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    # Featurization with only visual feature
    feature_extractors = FeatureExtractor(features=["visual"])
    featurizer = Featurizer(session, [PartRel],
                            feature_extractors=feature_extractors)

    # Count FeatureKeys produced by the visual feature library
    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    n_visual_features = session.query(FeatureKey).count()
    featurizer.clear(train=True)

    assert (n_default_feats == n_textual_features + n_tabular_features +
            n_structural_features + n_visual_features)