Exemple #1
0
def apply_labellling_functions(featurizer_output):
    """Apply labeling functions and attach training marginals to the output.

    Args:
        featurizer_output: Mapping that provides a ``'session'`` entry and a
            ``'candidate_variable'`` entry (an indexable collection of
            candidate classes). It is updated in place.

    Returns:
        The same ``featurizer_output`` object, with the predicted
        probabilities stored under ``'train_marginals'``.

    Note:
        ``lfs`` and ``config`` are not defined in this function; they are
        presumably module-level names -- verify against the enclosing module.
    """
    db_session, candidate_classes = (
        featurizer_output['session'],
        featurizer_output['candidate_variable'],
    )

    # Run the labeling functions over the candidates.
    lf_labeler = Labeler(db_session, candidate_classes)
    lf_labeler.apply(lfs=[lfs], train=True, parallelism=config.PARALLEL)

    # Only the first candidate class is queried; the label matrices are
    # requested for a one-element list and the first matrix is used below.
    first_class_cands = db_session.query(candidate_classes[0]).all()
    label_matrices = lf_labeler.get_label_matrices([first_class_cands])
    first_matrix = label_matrices[0]

    # Fit the label model on that matrix and predict on the same matrix.
    label_model = LabelModel(k=2)
    label_model.train_model(first_matrix, n_epochs=300, print_every=100)

    featurizer_output['train_marginals'] = label_model.predict_proba(first_matrix)
    return featurizer_output
Exemple #2
0
def _predict_true_candidates(disc_model, cands, features):
    """Return the candidates that ``disc_model`` predicts as ``TRUE``.

    Args:
        disc_model: Trained model exposing ``predict((cands, features), ...)``.
        cands: Indexable sequence of candidates for a single relation.
        features: Feature matrix matching ``cands``.

    Returns:
        List of the entries of ``cands`` whose prediction equals ``TRUE``.
    """
    # b=0.6 is passed through unchanged (presumably the decision threshold).
    test_score = disc_model.predict((cands, features), b=0.6, pos_label=TRUE)
    # Iterate the index array directly: np.nditer refuses zero-sized operands,
    # so it would raise ValueError when no candidate is predicted TRUE.
    # NOTE(review): assumes predictions are 1-D, as plain list indexing of
    # ``cands`` already requires.
    return [cands[idx] for idx in np.where(test_score == TRUE)[0]]


def _entity_level_prf1(true_pred, gold_file, test_docs, parts_by_doc):
    """Compute and log entity-level precision, recall and F1.

    Args:
        true_pred: Candidates predicted as positive.
        gold_file: Path of the gold-label CSV handed to ``entity_level_f1``.
        test_docs: Documents the evaluation is restricted to.
        parts_by_doc: Mapping forwarded to ``entity_level_f1``.

    Returns:
        Tuple ``(prec, rec, f1)``. A value is ``nan`` when its denominator
        is zero.
    """
    (TP, FP, FN) = entity_level_f1(
        true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc
    )

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    # A nan precision/recall makes the comparison False, so f1 becomes nan too.
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    return prec, rec, f1


def test_e2e():
    """Run an end-to-end test on documents of the hardware domain.

    The test parses the HTML/PDF fixtures, extracts mentions and candidates,
    featurizes and labels them, fits a label model, and finally trains several
    discriminative models whose entity-level F1 is checked against fixed
    bounds. The hard-coded counts are tied to the fixtures under
    ``tests/data``.
    """
    PARALLEL = 4

    max_docs = 12

    fonduer.init_logging(
        log_dir="log_folder",
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = fonduer.Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Expected per-document element counts, indexed like the sorted ``docs``.
    expected_counts = {
        "sentences": [799, 663, 784, 661, 513, 700, 528, 161, 228, 511, 331, 528],
        "tables": [9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9],
        "figures": [32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27],
        "captions": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    }
    for attr, expected in expected_counts.items():
        actual = [len(getattr(doc, attr)) for doc in docs[: len(expected)]]
        assert actual == expected, f"unexpected per-document {attr} counts"

    # Split by document name order: first half train, next quarter dev,
    # remainder test.
    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    for i, doc in enumerate(sorted(docs, key=lambda d: d.name)):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 138
    assert session.query(Volt).count() == 140
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299
    assert (
        len(
            mention_extractor.get_mentions(
                docs=[session.query(Document).filter(Document.name == "112823").first()]
            )[0]
        )
        == 70
    )

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    # The loop variable is deliberately not called ``docs`` so the sorted
    # document list above is not rebound.
    for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3493
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 416
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3493
    assert (
        len(
            candidate_extractor.get_candidates(
                docs=[session.query(Document).filter(Document.name == "112823").first()]
            )[0]
        )
        == 1432
    )

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    assert session.query(FeatureKey).count() == 1260

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NNP NN IN]"])
    assert session.query(FeatureKey).count() == 1259

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(
        session.query(FeatureKey)
        .filter(FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]")
        .one()
        .candidate_classes
    ) == {"part_temp", "part_volt"}
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]"
    ).one().candidate_classes == ["part_temp"]
    assert session.query(FeatureKey).count() == 1259

    # Inserting the removed key
    featurizer.upsert_keys(
        ["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp, PartVolt]
    )
    assert set(
        session.query(FeatureKey)
        .filter(FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]")
        .one()
        .candidate_classes
    ) == {"part_temp", "part_volt"}
    assert session.query(FeatureKey).count() == 1259
    # Removing the key again
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1258
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6478
    assert session.query(FeatureKey).count() == 4538
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3493, 4538)
    assert F_train[1].shape == (2985, 4538)
    assert len(featurizer.get_keys()) == 4538

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6692
    assert session.query(FeatureKey).count() == 4538
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (61, 4538)
    assert F_dev[1].shape == (153, 4538)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8252
    assert session.query(FeatureKey).count() == 4538
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (416, 4538)
    assert F_test[1].shape == (1144, 4538)

    gold_file = "tests/data/hardware_tutorial_gold.csv"

    labeler = Labeler(session, [PartTemp, PartVolt])

    # Gold labels are written to the GoldLabel table via the same apply() API.
    labeler.apply(
        docs=last_docs,
        lfs=[[gold], [gold]],
        table=GoldLabel,
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(GoldLabel).count() == 8252

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]

    # A flat LF list (instead of one list per candidate class) must be rejected.
    with pytest.raises(ValueError):
        labeler.apply(split=0, lfs=stg_temp_lfs, train=True, parallelism=PARALLEL)

    labeler.apply(
        docs=train_docs,
        lfs=[stg_temp_lfs, ce_v_max_lfs],
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    # Test Dropping LabelerKey
    labeler.drop_keys(["LF_storage_row"])
    assert len(labeler.get_keys()) == 8

    # Test Upserting LabelerKey
    labeler.upsert_keys(["LF_storage_row"])
    assert "LF_storage_row" in [label.name for label in labeler.get_keys()]

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3493, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3493, 1)

    gen_model = LabelModel()
    gen_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = gen_model.predict_proba(L_train[0])

    # First pass: only the initial labeling functions, with the training set
    # reused as the dev set.
    disc_model = LogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]),
        train_marginals,
        X_dev=(train_cands[0], F_train[0]),
        Y_dev=L_train_gold[0].reshape(-1),
        b=0.6,
        pos_label=TRUE,
        n_epochs=5,
        lr=0.001,
    )

    true_pred = _predict_true_candidates(disc_model, test_cands[0], F_test[0])

    # NOTE: pickle.load is only acceptable here because the file is a fixture
    # shipped with the test data; never use it on untrusted input.
    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    _, _, f1 = _entity_level_prf1(true_pred, gold_file, test_docs, parts_by_doc)

    assert 0.3 < f1 < 0.7

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0, lfs=[stg_temp_lfs_2, ce_v_max_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 16)

    gen_model = LabelModel()
    gen_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = gen_model.predict_proba(L_train[0])

    # Second pass: every discriminative model is trained on the refreshed
    # marginals with identical settings and must clear the same F1 bar.
    # The order matters: the last model (SparseLSTM) is reused below.
    for model_cls in (LogisticRegression, LSTM, SparseLogisticRegression, SparseLSTM):
        disc_model = model_cls()
        disc_model.train(
            (train_cands[0], F_train[0]), train_marginals, n_epochs=5, lr=0.001
        )

        true_pred = _predict_true_candidates(disc_model, test_cands[0], F_test[0])
        _, _, f1 = _entity_level_prf1(true_pred, gold_file, test_docs, parts_by_doc)

        assert f1 > 0.7, f"{model_cls.__name__}: entity-level f1 too low"

    # Evaluate mention level scores
    L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")
    Y_test = L_test_gold[0].reshape(-1)

    scores = disc_model.score((test_cands[0], F_test[0]), Y_test, b=0.6, pos_label=TRUE)

    logger.info(scores)

    assert scores["f1"] > 0.6
Exemple #3
0
def test_e2e():
    """Run an end-to-end test on documents of the hardware domain."""
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 12

    fonduer.init_logging(
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = fonduer.Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)
    assert session.query(Document).count() == max_docs

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    assert len(docs[0].sentences) == 799
    assert len(docs[1].sentences) == 663
    assert len(docs[2].sentences) == 784
    assert len(docs[3].sentences) == 661
    assert len(docs[4].sentences) == 513
    assert len(docs[5].sentences) == 700
    assert len(docs[6].sentences) == 528
    assert len(docs[7].sentences) == 161
    assert len(docs[8].sentences) == 228
    assert len(docs[9].sentences) == 511
    assert len(docs[10].sentences) == 331
    assert len(docs[11].sentences) == 528

    # Check table numbers
    assert len(docs[0].tables) == 9
    assert len(docs[1].tables) == 9
    assert len(docs[2].tables) == 14
    assert len(docs[3].tables) == 11
    assert len(docs[4].tables) == 11
    assert len(docs[5].tables) == 10
    assert len(docs[6].tables) == 10
    assert len(docs[7].tables) == 2
    assert len(docs[8].tables) == 7
    assert len(docs[9].tables) == 10
    assert len(docs[10].tables) == 6
    assert len(docs[11].tables) == 9

    # Check figure numbers
    assert len(docs[0].figures) == 32
    assert len(docs[1].figures) == 11
    assert len(docs[2].figures) == 38
    assert len(docs[3].figures) == 31
    assert len(docs[4].figures) == 7
    assert len(docs[5].figures) == 38
    assert len(docs[6].figures) == 10
    assert len(docs[7].figures) == 31
    assert len(docs[8].figures) == 4
    assert len(docs[9].figures) == 27
    assert len(docs[10].figures) == 5
    assert len(docs[11].figures) == 27

    # Check caption numbers
    assert len(docs[0].captions) == 0
    assert len(docs[1].captions) == 0
    assert len(docs[2].captions) == 0
    assert len(docs[3].captions) == 0
    assert len(docs[4].captions) == 0
    assert len(docs[5].captions) == 0
    assert len(docs[6].captions) == 0
    assert len(docs[7].captions) == 0
    assert len(docs[8].captions) == 0
    assert len(docs[9].captions) == 0
    assert len(docs[10].captions) == 0
    assert len(docs[11].captions) == 0

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (doc_name, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 138
    assert session.query(Volt).count() == 140
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299
    assert (len(
        mention_extractor.get_mentions(docs=[
            session.query(Document).filter(Document.name == "112823").first()
        ])[0]) == 70)

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt],
        throttlers=[temp_throttler, volt_throttler])

    for i, docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3493
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 416
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3493
    assert (len(
        candidate_extractor.get_candidates(docs=[
            session.query(Document).filter(Document.name == "112823").first()
        ])[0]) == 1432)

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    assert session.query(FeatureKey).count() == 1260

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NNP NN IN]"])
    assert session.query(FeatureKey).count() == 1259

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name ==
        "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes == ["part_temp"]
    assert session.query(FeatureKey).count() == 1259

    # Inserting the removed key
    featurizer.upsert_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                           candidate_classes=[PartTemp, PartVolt])
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    assert session.query(FeatureKey).count() == 1259
    # Removing the key again
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1258
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6478
    assert session.query(FeatureKey).count() == 4538
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3493, 4538)
    assert F_train[1].shape == (2985, 4538)
    assert len(featurizer.get_keys()) == 4538

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6692
    assert session.query(FeatureKey).count() == 4538
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (61, 4538)
    assert F_dev[1].shape == (153, 4538)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8252
    assert session.query(FeatureKey).count() == 4538
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (416, 4538)
    assert F_test[1].shape == (1144, 4538)

    gold_file = "tests/data/hardware_tutorial_gold.csv"

    labeler = Labeler(session, [PartTemp, PartVolt])

    labeler.apply(
        docs=last_docs,
        lfs=[[gold], [gold]],
        table=GoldLabel,
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(GoldLabel).count() == 8252

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]

    with pytest.raises(ValueError):
        labeler.apply(split=0,
                      lfs=stg_temp_lfs,
                      train=True,
                      parallelism=PARALLEL)

    labeler.apply(
        docs=train_docs,
        lfs=[stg_temp_lfs, ce_v_max_lfs],
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    # Test Dropping LabelerKey
    labeler.drop_keys(["LF_storage_row"])
    assert len(labeler.get_keys()) == 8

    # Test Upserting LabelerKey
    labeler.upsert_keys(["LF_storage_row"])
    assert "LF_storage_row" in [label.name for label in labeler.get_keys()]

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3493, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3493, 1)

    label_model = LabelModel()
    label_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    # Collect word counter
    word_counter = collect_word_counter(train_cands)

    emmental.init(fonduer.Meta.log_path)

    # Training config
    config = {
        "meta_config": {
            "verbose": False
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.0
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    f"{ATTRIBUTE}/{ATTRIBUTE}/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }
    emmental.Meta.update_config(config=config)

    # Generate word embedding module
    arity = 2
    # Geneate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)

    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LogisticRegression")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(ATTRIBUTE, test_cands[0], F_test[0],
                               emb_layer.word2id, 2),
        split="test",
        batch_size=100,
        shuffle=False,
    )

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.6)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 < 0.7 and f1 > 0.3

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0,
                   lfs=[stg_temp_lfs_2, ce_v_max_lfs],
                   parallelism=PARALLEL)
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 16)

    label_model = LabelModel()
    label_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    valid_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            np.argmax(train_marginals, axis=1),
            train_idxs,
        ),
        split="valid",
        batch_size=100,
        shuffle=False,
    )

    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LogisticRegression")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader, valid_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7

    # Testing LSTM
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LSTM")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7
def test_incremental(caplog):
    """Run an end-to-end test on incremental additions.

    A single document is parsed and taken through mention extraction,
    candidate extraction, featurization and labeling. A second document is
    then parsed with ``clear=False`` and each stage is re-applied or updated
    on the new document only, asserting the resulting row counts after every
    step.
    """
    caplog.set_level(logging.INFO)

    PARALLEL = 1

    max_docs = 1

    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/dtc114w.html"
    pdf_path = "tests/data/pdf/dtc114w.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    docs = corpus_parser.get_documents()
    # Use get_last_documents() here: fetching both lists with get_documents()
    # compared the result with itself, so the assertion could never fail.
    last_docs = corpus_parser.get_last_documents()

    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(session, [Part, Temp],
                                         [part_ngrams, temp_ngrams],
                                         [part_matcher, temp_matcher])

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(session, [PartTemp],
                                             throttlers=[temp_throttler])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 70
    assert session.query(Candidate).count() == 70

    # Grab candidate lists (one inner list per candidate class)
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 70

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 70
    assert session.query(FeatureKey).count() == 512

    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (70, 512)
    assert len(featurizer.get_keys()) == 512

    # Test Dropping FeatureKey
    # NOTE(review): the key count is expected to stay at 512 after this drop;
    # presumably the named key is not among the stored keys -- confirm.
    featurizer.drop_keys(["CORE_e1_LENGTH_1"])
    assert session.query(FeatureKey).count() == 512

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    labeler.apply(split=0,
                  lfs=[stg_temp_lfs],
                  train=True,
                  parallelism=PARALLEL)
    assert session.query(Label).count() == 70

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (70, 5)
    assert len(labeler.get_keys()) == 5

    # Parse a second document on top of the existing corpus (clear=False).
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser.apply(doc_preprocessor, pdf_path=pdf_path, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    new_docs = corpus_parser.get_last_documents()

    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # Get mentions from just the new docs
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)

    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Just run candidate extraction and assign to split 0
    candidate_extractor.apply(new_docs,
                              split=0,
                              parallelism=PARALLEL,
                              clear=False)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1502

    # Update features
    featurizer.update(new_docs, parallelism=PARALLEL)
    assert session.query(Feature).count() == 1502
    assert session.query(FeatureKey).count() == 2573
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (1502, 2573)
    assert len(featurizer.get_keys()) == 2573

    # Update Labels
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 1502
    assert session.query(LabelKey).count() == 6
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (1502, 6)

    # Test clear
    featurizer.clear(train=True)
    assert session.query(FeatureKey).count() == 0
Exemple #5
0
def _precision_recall_f1(tp_len, fp_len, fn_len):
    """Return ``(precision, recall, f1)`` computed from TP/FP/FN counts.

    Each ratio is NaN when its denominator is not positive, so an empty
    prediction set yields NaN instead of raising ``ZeroDivisionError``.
    """
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")
    return prec, rec, f1


def test_e2e(caplog):
    """Run an end-to-end test on documents of the hardware domain.

    The test parses the HTML/PDF fixtures, checks per-document structure
    counts, extracts Part/Temp/Volt mentions and PartTemp/PartVolt candidates
    over a train/dev/test split, featurizes and labels them, fits a label
    model, and finally trains several discriminative models, asserting bounds
    on the entity-level F1 obtained on the test split.
    """
    caplog.set_level(logging.INFO)

    PARALLEL = 4

    max_docs = 12

    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)
    assert session.query(Document).count() == max_docs

    num_docs = session.query(Document).count()
    logger.info("Docs: {}".format(num_docs))
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info("Sentences: {}".format(num_sentences))

    # Divide into test and train
    docs = corpus_parser.get_documents()
    ld = len(docs)
    assert ld == len(corpus_parser.get_last_documents())

    # Expected counts per document, indexed like the list from get_documents().
    expected_sentences = [
        799, 663, 784, 661, 513, 700, 528, 161, 228, 511, 331, 528
    ]
    expected_tables = [9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9]
    expected_figures = [32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27]

    for idx, expected in enumerate(expected_sentences):
        assert len(docs[idx].sentences) == expected

    # Check table numbers
    for idx, expected in enumerate(expected_tables):
        assert len(docs[idx].tables) == expected

    # Check figure numbers
    for idx, expected in enumerate(expected_figures):
        assert len(docs[idx].figures) == expected

    # Check caption numbers (none of the documents has captions)
    for idx in range(max_docs):
        assert len(docs[idx].captions) == 0

    # Split the name-sorted documents 50% / 25% / 25% into train/dev/test.
    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (doc_name, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 147
    assert session.query(Volt).count() == 140
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299
    assert (len(
        mention_extractor.get_mentions(docs=[
            session.query(Document).filter(Document.name == "112823").first()
        ])[0]) == 70)

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt],
        throttlers=[temp_throttler, volt_throttler])

    # A dedicated loop name keeps ``docs`` bound to the full document list
    # instead of silently rebinding it to the last split.
    for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3684
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 72
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 448
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3684
    assert (len(
        candidate_extractor.get_candidates(docs=[
            session.query(Document).filter(Document.name == "112823").first()
        ])[0]) == 1496)

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 225
    assert session.query(FeatureKey).count() == 1179

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NFP NN NFP]"])
    assert session.query(FeatureKey).count() == 1178

    # The key starts out attached to both relations; dropping it for PartVolt
    # should only remove the part_volt relation and leave part_temp.
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name ==
        "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes == ["part_temp"]
    assert session.query(FeatureKey).count() == 1178
    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1177
    session.query(Feature).delete()
    session.query(FeatureKey).delete()

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6669
    assert session.query(FeatureKey).count() == 4161
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3684, 4161)
    assert F_train[1].shape == (2985, 4161)
    assert len(featurizer.get_keys()) == 4161

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6894
    assert session.query(FeatureKey).count() == 4161
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (72, 4161)
    assert F_dev[1].shape == (153, 4161)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8486
    assert session.query(FeatureKey).count() == 4161
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (448, 4161)
    assert F_test[1].shape == (1144, 4161)

    gold_file = "tests/data/hardware_tutorial_gold.csv"
    load_hardware_labels(session,
                         PartTemp,
                         gold_file,
                         ATTRIBUTE,
                         annotator_name="gold")
    assert session.query(GoldLabel).count() == 4204
    load_hardware_labels(session,
                         PartVolt,
                         gold_file,
                         ATTRIBUTE,
                         annotator_name="gold")
    assert session.query(GoldLabel).count() == 8486

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]

    labeler = Labeler(session, [PartTemp, PartVolt])

    # A flat list of LFs (rather than one list per candidate class) must be
    # rejected with a ValueError.
    with pytest.raises(ValueError):
        labeler.apply(split=0,
                      lfs=stg_temp_lfs,
                      train=True,
                      parallelism=PARALLEL)

    labeler.apply(split=0,
                  lfs=[stg_temp_lfs, ce_v_max_lfs],
                  train=True,
                  parallelism=PARALLEL)
    assert session.query(Label).count() == 6669
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3684, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3684, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3684, 1)

    gen_model = LabelModel(k=2)
    gen_model.train_model(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])[:, 1]

    # Repository-local fixture; pickle must never be used on untrusted files.
    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    def fit_and_score(disc_model, marginals, n_epochs):
        """Train ``disc_model`` on the PartTemp training set and return its
        entity-level F1 on the test split, logging precision/recall/F1."""
        disc_model.train((train_cands[0], F_train[0]),
                         marginals,
                         n_epochs=n_epochs,
                         lr=0.001)

        test_score = disc_model.predictions((test_cands[0], F_test[0]), b=0.6)
        # Index the np.where result directly: np.nditer raises ValueError on a
        # zero-size operand, i.e. when no candidate is predicted positive.
        true_pred = [
            test_cands[0][idx] for idx in np.where(test_score > 0)[0]
        ]

        (TP, FP, FN) = entity_level_f1(true_pred,
                                       gold_file,
                                       ATTRIBUTE,
                                       test_docs,
                                       parts_by_doc=parts_by_doc)

        prec, rec, f1 = _precision_recall_f1(len(TP), len(FP), len(FN))

        logger.info("prec: {}".format(prec))
        logger.info("rec: {}".format(rec))
        logger.info("f1: {}".format(f1))
        return f1

    # With only the first set of labeling functions the score is mediocre.
    f1 = fit_and_score(LogisticRegression(), train_marginals, n_epochs=20)
    assert f1 < 0.7 and f1 > 0.3

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0,
                   lfs=[stg_temp_lfs_2, ce_v_max_lfs],
                   parallelism=PARALLEL)
    assert session.query(Label).count() == 6669
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3684, 16)

    gen_model = LabelModel(k=2)
    gen_model.train_model(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])[:, 1]

    # Every model below is trained on the marginals from the updated labels.
    f1 = fit_and_score(LogisticRegression(), train_marginals, n_epochs=20)
    assert f1 > 0.7

    # Testing LSTM
    f1 = fit_and_score(LSTM(), train_marginals, n_epochs=5)
    assert f1 > 0.7

    # Testing Sparse Logistic Regression
    f1 = fit_and_score(SparseLogisticRegression(),
                       train_marginals,
                       n_epochs=20)
    assert f1 > 0.7

    # Testing Sparse LSTM
    f1 = fit_and_score(SparseLSTM(), train_marginals, n_epochs=5)
    assert f1 > 0.7
Exemple #6
0
def test_incremental():
    """Run an end-to-end test on incremental additions.

    A single document is parsed and taken through mention extraction,
    candidate extraction, featurization and labeling. A second document is
    then parsed with ``clear=False``; mention and candidate extraction are
    each applied twice on it to check that the counts do not grow on the
    second run, and features and labels are updated, the latter with a
    redefined ``LF_storage_row``.
    """
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 1

    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/dtc114w.html"
    pdf_path = "tests/data/pdf/dtc114w.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    docs = corpus_parser.get_documents()
    # Use get_last_documents() here: fetching both lists with get_documents()
    # compared the result with itself, so the assertion could never fail.
    last_docs = corpus_parser.get_last_documents()

    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(session, [Part, Temp],
                                         [part_ngrams, temp_ngrams],
                                         [part_matcher, temp_matcher])

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(session, [PartTemp],
                                             throttlers=[temp_throttler])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 70
    assert session.query(Candidate).count() == 70

    # Grab candidate lists (one inner list per candidate class)
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 70

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 70
    assert session.query(FeatureKey).count() == 512

    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (70, 512)
    assert len(featurizer.get_keys()) == 512

    # Test Dropping FeatureKey
    # NOTE(review): the key count is expected to stay at 512 after this drop;
    # presumably the named key is not among the stored keys -- confirm.
    featurizer.drop_keys(["CORE_e1_LENGTH_1"])
    assert session.query(FeatureKey).count() == 512

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    labeler.apply(split=0,
                  lfs=[stg_temp_lfs],
                  train=True,
                  parallelism=PARALLEL)
    assert session.query(Label).count() == 70

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (70, 5)
    assert len(labeler.get_keys()) == 5

    # Parse a second document on top of the existing corpus (clear=False).
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser.apply(doc_preprocessor, pdf_path=pdf_path, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    new_docs = corpus_parser.get_last_documents()

    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # Get mentions from just the new docs
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)
    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Test if existing mentions are skipped.
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)
    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Just run candidate extraction and assign to split 0
    candidate_extractor.apply(new_docs,
                              split=0,
                              parallelism=PARALLEL,
                              clear=False)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1502

    # Test if existing candidates are skipped.
    candidate_extractor.apply(new_docs,
                              split=0,
                              parallelism=PARALLEL,
                              clear=False)
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1502

    # Update features
    featurizer.update(new_docs, parallelism=PARALLEL)
    assert session.query(Feature).count() == 1502
    assert session.query(FeatureKey).count() == 2573
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (1502, 2573)
    assert len(featurizer.get_keys()) == 2573

    # Update LF_storage_row. Now it always returns ABSTAIN.
    # It is registered under the original name "LF_storage_row".
    @labeling_function(name="LF_storage_row")
    def LF_storage_row_updated(c):
        return ABSTAIN

    stg_temp_lfs = [
        LF_storage_row_updated,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    # Update Labels for both the original and the newly added documents
    labeler.update(docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 1502
    # Only 5 because LF_storage_row doesn't apply to any doc (always ABSTAIN)
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (1502, 5)

    # Test clear
    featurizer.clear(train=True)
    assert session.query(FeatureKey).count() == 0
Exemple #7
0
def test_incremental(caplog):
    """Check that a second document can be added to the pipeline incrementally.

    A first document is parsed and pushed through mention extraction,
    candidate extraction, featurization and labeling. A second document is
    then parsed with ``clear=False`` and only that new document is fed to the
    same stages. Row counts and matrix shapes are asserted after every stage.
    """
    caplog.set_level(logging.INFO)

    # The guard exists because SpaCy on mac has issues with parallel parsing.
    # NOTE: os.name is "posix" on Linux as well as macOS, so both take the
    # single-core path.
    on_posix = os.name == "posix"
    if on_posix:
        logger.info("Using single core.")
    n_workers = 1 if on_posix else 2  # Travis only gives 2 cores

    doc_limit = 1

    session = Meta.init("postgres://localhost:5432/" + DB).Session()

    def row_count(model):
        """Return the number of rows currently stored for ``model``."""
        return session.query(model).count()

    # ---- First document: parse ----
    first_html = "tests/data/html/dtc114w.html"
    first_pdf = "tests/data/pdf/dtc114w.pdf"

    first_preprocessor = HTMLDocPreprocessor(first_html, max_docs=doc_limit)

    corpus_parser = Parser(
        session,
        parallelism=n_workers,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=first_pdf,
    )
    corpus_parser.apply(first_preprocessor)

    parsed_total = row_count(Document)
    logger.info("Docs: {}".format(parsed_total))
    assert parsed_total == doc_limit

    docs = corpus_parser.get_documents()

    # ---- First document: mentions ----
    mention_spaces = [
        MentionNgramsPart(parts_by_doc=None, n_max=3),
        MentionNgramsTemp(n_max=2),
    ]

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp],
        mention_spaces,
        [part_matcher, temp_matcher],
    )
    mention_extractor.apply(docs, parallelism=n_workers)

    assert row_count(Part) == 11
    assert row_count(Temp) == 9

    # ---- First document: candidates ----
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )
    candidate_extractor.apply(docs, split=0, parallelism=n_workers)

    split0_query = session.query(PartTemp).filter(PartTemp.split == 0)
    assert split0_query.count() == 78

    # One list per candidate class, hence a single entry here.
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 78

    # ---- First document: features ----
    featurizer = Featurizer(session, [PartTemp])
    featurizer.apply(split=0, train=True, parallelism=n_workers)

    assert row_count(Feature) == 78
    assert row_count(FeatureKey) == 496
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (78, 496)
    assert len(featurizer.get_keys()) == 496

    # ---- First document: labels ----
    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])
    labeler.apply(split=0, lfs=[stg_temp_lfs], train=True, parallelism=n_workers)
    assert row_count(Label) == 78

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert row_count(LabelKey) == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (78, 5)
    assert len(labeler.get_keys()) == 5

    # ---- Second document: parse without clearing the first ----
    second_html = "tests/data/html/112823.html"
    second_pdf = "tests/data/pdf/112823.pdf"

    second_preprocessor = HTMLDocPreprocessor(second_html, max_docs=doc_limit)
    corpus_parser.apply(second_preprocessor, pdf_path=second_pdf, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    new_docs = corpus_parser.get_last_documents()
    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # ---- Second document: mentions from just the new docs ----
    mention_extractor.apply(new_docs, parallelism=n_workers, clear=False)

    assert row_count(Part) == 81
    assert row_count(Temp) == 33

    # ---- Second document: candidates, again assigned to split 0 ----
    candidate_extractor.apply(
        new_docs, split=0, parallelism=n_workers, clear=False
    )

    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1574

    # ---- Second document: update features ----
    featurizer.update(new_docs, parallelism=n_workers)
    assert row_count(Feature) == 1574
    assert row_count(FeatureKey) == 2425
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (1574, 2425)
    assert len(featurizer.get_keys()) == 2425

    # ---- Second document: update labels ----
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=n_workers)
    assert row_count(Label) == 1574
    assert row_count(LabelKey) == 6
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (1574, 6)
Exemple #8
0
# Script fragment: featurize and label the split-0 candidates, compare the
# labeling functions against gold labels, then create a label model.
# It relies on names defined earlier in the script (featurizer, train_cands,
# session, candidate_classes, PARALLEL); imports are interleaved with the
# statements that first need them.

# Featurize split 0 with train=True and fetch the matrices for train_cands.
featurizer.apply(split=0, train=True, parallelism=PARALLEL)
F_train = featurizer.get_feature_matrices(train_cands)

from wiki_table_utils import load_president_gold_labels

# Load gold labels from the CSV file under the annotator name "gold".
gold_file = "data/president_tutorial_gold.csv"
load_president_gold_labels(
    session, candidate_classes, gold_file, annotator_name="gold"
)

# NOTE(review): TRUE is not used in the visible statements; presumably it is
# needed further down the script — confirm before removing.
from lfconfig import president_name_pob_lfs, TRUE

from fonduer.supervision import Labeler

# Apply the labeling functions to split 0 and fetch the label matrices.
labeler = Labeler(session, candidate_classes)
labeler.apply(split=0, lfs=[president_name_pob_lfs], train=True, parallelism=PARALLEL)
L_train = labeler.get_label_matrices(train_cands)

# Gold labels for the same candidates, selected by the "gold" annotator name.
L_gold_train = labeler.get_gold_labels(train_cands, annotator="gold")

from metal import analysis

# Summarize the labeling functions against the gold labels. Y is the first
# gold matrix densified, reshaped to a single row and converted to a list.
analysis.lf_summary(
    L_train[0],
    lf_names=labeler.get_keys(),
    Y=L_gold_train[0].todense().reshape(-1).tolist()[0],
)

from metal.label_model import LabelModel

# Create the label model with k=2 (presumably two classes — TODO confirm).
gen_model = LabelModel(k=2)
Exemple #9
0
def main(
    conn_string,
    gain=False,
    current=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    gpu=None,
    parallel=8,
    log_dir="logs",
    verbose=False,
):
    """Run the gain / supply-current extraction pipeline end to end.

    Stages: parsing, mention extraction, candidate extraction, featurization,
    labeling, then per-relation training, scoring and CSV dumps.

    Args:
        conn_string: Database connection string handed to ``Meta.init``.
        gain: Enable the "gain" relation.
        current: Enable the "current" relation.
        max_docs: Document limit forwarded to ``parse_dataset``.
        parse: Forwarded to ``parse_dataset`` as its ``first_time`` argument.
        first_time: When true, run mention extraction, candidate extraction,
            featurization and LF application, and write the feature matrices
            to ``F_*.pkl`` next to this file. When false, the feature
            matrices are loaded from those pickle files instead.
        re_label: When ``first_time`` is false, re-apply the labeling
            functions with ``labeler.update``.
        gpu: If truthy, exported as ``CUDA_VISIBLE_DEVICES`` and forwarded to
            ``discriminative_model``.
        parallel: Parallelism passed to every pipeline stage.
        log_dir: Log directory, relative to this file's directory; a falsy
            value falls back to "logs".
        verbose: Log at INFO level instead of WARNING.
    """
    # Setup initial configuration
    if gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    if not log_dir:
        log_dir = "logs"

    level = logging.INFO if verbose else logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    # Position in rel_list == position of the relation in every per-class
    # list returned by the extractor/featurizer/labeler below.
    rel_list = []
    if gain:
        rel_list.append("gain")

    if current:
        rel_list.append("current")

    logger.info("=" * 30)
    logger.info(f"Running with parallel: {parallel}, max_docs: {max_docs}")

    session = Meta.init(conn_string).Session()

    # Parsing
    start = timer()
    logger.info("Starting parsing...")
    docs, train_docs, dev_docs, test_docs = parse_dataset(session,
                                                          dirname,
                                                          first_time=parse,
                                                          parallel=parallel,
                                                          max_docs=max_docs)
    logger.debug("Done")
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of Documents: {len(docs)}")
    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    if gain:
        Gain = mention_subclass("Gain")
        gain_matcher = get_gain_matcher()
        gain_ngrams = MentionNgrams(n_max=2)
        mentions.append(Gain)
        ngrams.append(gain_ngrams)
        matchers.append(gain_matcher)

    if current:
        Current = mention_subclass("SupplyCurrent")
        current_matcher = get_supply_current_matcher()
        current_ngrams = MentionNgramsCurrent(n_max=3)
        mentions.append(Current)
        ngrams.append(current_ngrams)
        matchers.append(current_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")

    if gain:
        logger.info(f"Total Gain: {session.query(Gain).count()}")

    if current:
        logger.info(f"Total Current: {session.query(Current).count()}")

    cand_classes = []
    if gain:
        GainCand = candidate_subclass("GainCand", [Gain])
        cand_classes.append(GainCand)
    if current:
        CurrentCand = candidate_subclass("CurrentCand", [Current])
        cand_classes.append(CurrentCand)

    candidate_extractor = CandidateExtractor(session, cand_classes)

    if first_time:
        # A dedicated loop name keeps the full `docs` list from being
        # rebound to the last split.
        for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(split_docs,
                                      split=i,
                                      parallelism=parallel)

    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)
    # Sum over however many candidate classes are enabled; indexing [0] and
    # [1] unconditionally raises IndexError when only one relation is on.
    logger.info(
        f"Total train candidate: {sum(len(c) for c in train_cands)}")
    logger.info(
        f"Total dev candidate: {sum(len(c) for c in dev_cands)}")
    logger.info(
        f"Total test candidate: {sum(len(c) for c in test_cands)}")

    logger.info("Done w/ candidate extraction.")
    end = timer()
    logger.warning(f"CE Time (min): {((end - start) / 60.0):.1f}")

    # First, check total recall
    #  result = entity_level_scores(dev_cands[0], corpus=dev_docs)
    #  logger.info(f"Gain Total Dev Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #  result = entity_level_scores(test_cands[0], corpus=test_docs)
    #  logger.info(f"Gain Total Test Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #
    #  result = entity_level_scores(dev_cands[1], corpus=dev_docs, is_gain=False)
    #  logger.info(f"Current Total Dev Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #  result = entity_level_scores(test_cands[1], corpus=test_docs, is_gain=False)
    #  logger.info(f"Current Test Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")

    start = timer()
    featurizer = Featurizer(session, cand_classes)

    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True, parallelism=parallel)
        featurizer.apply(split=1, parallelism=parallel)
        featurizer.apply(split=2, parallelism=parallel)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    # Serialize feature matrices on first run
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        # Context managers make sure each pickle is flushed and closed.
        for filename, matrices in (
            ("F_train.pkl", F_train),
            ("F_dev.pkl", F_dev),
            ("F_test.pkl", F_test),
        ):
            with open(os.path.join(dirname, filename), "wb") as f:
                pickle.dump(matrices, f)
    else:
        # NOTE: pickle.load executes arbitrary code; these files are expected
        # to be the ones written by a previous first_time run.
        with open(os.path.join(dirname, "F_train.pkl"), "rb") as f:
            F_train = pickle.load(f)
        with open(os.path.join(dirname, "F_dev.pkl"), "rb") as f:
            F_dev = pickle.load(f)
        with open(os.path.join(dirname, "F_test.pkl"), "rb") as f:
            F_test = pickle.load(f)
    logger.info("Done.")

    start = timer()
    logger.info("Labeling training data...")
    labeler = Labeler(session, cand_classes)
    lfs = []
    if gain:
        lfs.append(gain_lfs)

    if current:
        lfs.append(current_lfs)

    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
    elif re_label:
        logger.info("Re-applying LFs...")
        labeler.update(split=0, lfs=lfs, parallelism=parallel)

    logger.info("Done...")

    logger.info("Getting label matrices...")
    L_train = labeler.get_label_matrices(train_cands)
    logger.info("Done...")

    end = timer()
    logger.warning(
        f"Weak Supervision Time (min): {((end - start) / 60.0):.1f}")

    if gain:
        relation = "gain"
        idx = rel_list.index(relation)

        logger.info("Score Gain.")
        dev_gold_entities = get_gold_set(is_gain=True)
        # A dev candidate is TRUE if any of its entities is in the gold set.
        L_dev_gt = []
        for c in dev_cands[idx]:
            flag = FALSE
            for entity in cand_to_entity(c, is_gain=True):
                if entity in dev_gold_entities:
                    flag = TRUE
            L_dev_gt.append(flag)

        marginals = generative_model(L_train[idx])
        disc_models = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals,
            X_dev=(dev_cands[idx], F_dev[idx]),
            Y_dev=L_dev_gt,
            n_epochs=500,
            gpu=gpu,
        )
        best_result, best_b = scoring(disc_models,
                                      test_cands[idx],
                                      test_docs,
                                      F_test[idx],
                                      num=50)

        print_scores(relation, best_result, best_b)

        logger.info("Output CSV files for Opo and Digi-key Analysis.")
        Y_prob = disc_models.marginals((train_cands[idx], F_train[idx]))
        output_csv(train_cands[idx], Y_prob, is_gain=True)

        Y_prob = disc_models.marginals((test_cands[idx], F_test[idx]))
        output_csv(test_cands[idx], Y_prob, is_gain=True, append=True)
        dump_candidates(test_cands[idx],
                        Y_prob,
                        "gain_test_probs.csv",
                        is_gain=True)

        Y_prob = disc_models.marginals((dev_cands[idx], F_dev[idx]))
        output_csv(dev_cands[idx], Y_prob, is_gain=True, append=True)
        dump_candidates(dev_cands[idx],
                        Y_prob,
                        "gain_dev_probs.csv",
                        is_gain=True)

    if current:
        relation = "current"
        idx = rel_list.index(relation)

        logger.info("Score Current.")
        dev_gold_entities = get_gold_set(is_gain=False)
        # A dev candidate is TRUE if any of its entities is in the gold set.
        L_dev_gt = []
        for c in dev_cands[idx]:
            flag = FALSE
            for entity in cand_to_entity(c, is_gain=False):
                if entity in dev_gold_entities:
                    flag = TRUE
            L_dev_gt.append(flag)

        marginals = generative_model(L_train[idx])

        disc_models = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals,
            X_dev=(dev_cands[idx], F_dev[idx]),
            Y_dev=L_dev_gt,
            n_epochs=100,
            gpu=gpu,
        )
        best_result, best_b = scoring(disc_models,
                                      test_cands[idx],
                                      test_docs,
                                      F_test[idx],
                                      is_gain=False,
                                      num=50)

        print_scores(relation, best_result, best_b)

        logger.info("Output CSV files for Opo and Digi-key Analysis.")
        # Dump CSV files for digi-key analysis and Opo comparison
        Y_prob = disc_models.marginals((train_cands[idx], F_train[idx]))
        output_csv(train_cands[idx], Y_prob, is_gain=False)

        Y_prob = disc_models.marginals((test_cands[idx], F_test[idx]))
        output_csv(test_cands[idx], Y_prob, is_gain=False, append=True)
        dump_candidates(test_cands[idx],
                        Y_prob,
                        "current_test_probs.csv",
                        is_gain=False)

        Y_prob = disc_models.marginals((dev_cands[idx], F_dev[idx]))
        output_csv(dev_cands[idx], Y_prob, is_gain=False, append=True)
        dump_candidates(dev_cands[idx],
                        Y_prob,
                        "current_dev_probs.csv",
                        is_gain=False)

    end = timer()
    logger.warning(
        f"Classification AND dump data Time (min): {((end - start) / 60.0):.1f}"
    )
Exemple #10
0
def main(
    conn_string,
    stg_temp_min=False,
    stg_temp_max=False,
    polarity=False,
    ce_v_max=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Run the transistor relation-extraction pipeline end to end.

    Stages: parsing, mention extraction, candidate extraction, recall check,
    featurization, labeling, marginal estimation, Emmental training, then
    scoring and probability dumps per relation.

    Args:
        conn_string: Database connection string handed to ``Meta.init``.
        stg_temp_min: Enable the "stg_temp_min" relation.
        stg_temp_max: Enable the "stg_temp_max" relation.
        polarity: Enable the "polarity" relation.
        ce_v_max: Enable the "ce_v_max" relation.
        max_docs: Document limit forwarded to ``parse_dataset``.
        parse: Forwarded to ``parse_dataset`` as its ``first_time`` argument.
        first_time: When true, run mention/candidate extraction,
            featurization, LF application and marginal estimation, and write
            the results to pickle files next to this file. When false, the
            feature matrices and marginals are loaded from those pickles.
        re_label: When ``first_time`` is false, re-apply the labeling
            functions with ``labeler.update``.
        parallel: Parallelism for parsing, mention and candidate extraction.
        log_dir: Log directory, relative to this file's directory; a falsy
            value falls back to "logs".
        verbose: Log at INFO level instead of WARNING.
    """
    if not log_dir:
        log_dir = "logs"

    level = logging.INFO if verbose else logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    # Position in rel_list == position of the relation in every per-class
    # list (candidates, feature matrices, label matrices, marginals) below.
    rel_list = []
    if stg_temp_min:
        rel_list.append("stg_temp_min")

    if stg_temp_max:
        rel_list.append("stg_temp_max")

    if polarity:
        rel_list.append("polarity")

    if ce_v_max:
        rel_list.append("ce_v_max")

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info("Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(session,
                                                          dirname,
                                                          first_time=parse,
                                                          parallel=parallel,
                                                          max_docs=max_docs)
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Part mentions are always extracted; the rest only when enabled.
    Part = mention_subclass("Part")
    part_matcher = get_matcher("part")
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)

    mentions.append(Part)
    ngrams.append(part_ngrams)
    matchers.append(part_matcher)

    if stg_temp_min:
        StgTempMin = mention_subclass("StgTempMin")
        stg_temp_min_matcher = get_matcher("stg_temp_min")
        stg_temp_min_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMin)
        ngrams.append(stg_temp_min_ngrams)
        matchers.append(stg_temp_min_matcher)

    if stg_temp_max:
        StgTempMax = mention_subclass("StgTempMax")
        stg_temp_max_matcher = get_matcher("stg_temp_max")
        stg_temp_max_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMax)
        ngrams.append(stg_temp_max_ngrams)
        matchers.append(stg_temp_max_matcher)

    if polarity:
        Polarity = mention_subclass("Polarity")
        polarity_matcher = get_matcher("polarity")
        polarity_ngrams = MentionNgrams(n_max=1)

        mentions.append(Polarity)
        ngrams.append(polarity_ngrams)
        matchers.append(polarity_matcher)

    if ce_v_max:
        CeVMax = mention_subclass("CeVMax")
        ce_v_max_matcher = get_matcher("ce_v_max")
        ce_v_max_ngrams = MentionNgramsVolt(n_max=1)

        mentions.append(CeVMax)
        ngrams.append(ce_v_max_ngrams)
        matchers.append(ce_v_max_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")
    logger.info(f"Total Part: {session.query(Part).count()}")
    if stg_temp_min:
        logger.info(f"Total StgTempMin: {session.query(StgTempMin).count()}")
    if stg_temp_max:
        logger.info(f"Total StgTempMax: {session.query(StgTempMax).count()}")
    if polarity:
        logger.info(f"Total Polarity: {session.query(Polarity).count()}")
    if ce_v_max:
        logger.info(f"Total CeVMax: {session.query(CeVMax).count()}")

    # Candidate Extraction
    cands = []
    throttlers = []
    if stg_temp_min:
        PartStgTempMin = candidate_subclass("PartStgTempMin",
                                            [Part, StgTempMin])
        stg_temp_min_throttler = stg_temp_filter

        cands.append(PartStgTempMin)
        throttlers.append(stg_temp_min_throttler)

    if stg_temp_max:
        PartStgTempMax = candidate_subclass("PartStgTempMax",
                                            [Part, StgTempMax])
        stg_temp_max_throttler = stg_temp_filter

        cands.append(PartStgTempMax)
        throttlers.append(stg_temp_max_throttler)

    if polarity:
        PartPolarity = candidate_subclass("PartPolarity", [Part, Polarity])
        polarity_throttler = polarity_filter

        cands.append(PartPolarity)
        throttlers.append(polarity_throttler)

    if ce_v_max:
        PartCeVMax = candidate_subclass("PartCeVMax", [Part, CeVMax])
        ce_v_max_throttler = ce_v_max_filter

        cands.append(PartCeVMax)
        throttlers.append(ce_v_max_throttler)

    candidate_extractor = CandidateExtractor(session,
                                             cands,
                                             throttlers=throttlers)

    if first_time:
        # A dedicated loop name keeps the full `docs` list from being
        # rebound to the last split.
        for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(split_docs,
                                      split=i,
                                      parallelism=parallel)
            num_cands = session.query(Candidate).filter(
                Candidate.split == i).count()
            logger.info(f"Candidates in split={i}: {num_cands}")

    # These must be sorted for deterministic behavior.
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)

    end = timer()
    logger.warning(
        f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"Total train candidate: {sum(len(_) for _ in train_cands)}")
    logger.info(f"Total dev candidate: {sum(len(_) for _ in dev_cands)}")
    logger.info(f"Total test candidate: {sum(len(_) for _ in test_cands)}")

    pickle_file = os.path.join(dirname, "data/parts_by_doc_new.pkl")
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    # Check total recall
    for i, name in enumerate(rel_list):
        logger.info(name)
        result = entity_level_scores(
            candidates_to_entities(dev_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=dev_docs,
        )
        logger.info(f"{name} Total Dev Recall: {result.rec:.3f}")
        result = entity_level_scores(
            candidates_to_entities(test_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=test_docs,
        )
        logger.info(f"{name} Total Test Recall: {result.rec:.3f}")

    # Featurization
    start = timer()
    # Fresh list of the enabled candidate classes, in rel_list order.
    cands = []
    if stg_temp_min:
        cands.append(PartStgTempMin)

    if stg_temp_max:
        cands.append(PartStgTempMax)

    if polarity:
        cands.append(PartPolarity)

    if ce_v_max:
        cands.append(PartCeVMax)

    # Using parallelism = 1 for deterministic behavior.
    featurizer = Featurizer(session, cands, parallelism=1)
    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True)
        featurizer.apply(split=1)
        featurizer.apply(split=2)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        # Key the matrices by relation name so a later run with a different
        # set of enabled relations can still pick out the ones it needs.
        F_train_dict = {}
        F_dev_dict = {}
        F_test_dict = {}
        for idx, relation in enumerate(rel_list):
            F_train_dict[relation] = F_train[idx]
            F_dev_dict[relation] = F_dev[idx]
            F_test_dict[relation] = F_test[idx]

        # Context managers make sure each pickle is flushed and closed.
        for filename, payload in (
            ("F_train_dict.pkl", F_train_dict),
            ("F_dev_dict.pkl", F_dev_dict),
            ("F_test_dict.pkl", F_test_dict),
        ):
            with open(os.path.join(dirname, filename), "wb") as f:
                pickle.dump(payload, f)
    else:
        # NOTE: pickle.load executes arbitrary code; these files are expected
        # to be the ones written by a previous first_time run.
        with open(os.path.join(dirname, "F_train_dict.pkl"), "rb") as f:
            F_train_dict = pickle.load(f)
        with open(os.path.join(dirname, "F_dev_dict.pkl"), "rb") as f:
            F_dev_dict = pickle.load(f)
        with open(os.path.join(dirname, "F_test_dict.pkl"), "rb") as f:
            F_test_dict = pickle.load(f)

        F_train = [F_train_dict[relation] for relation in rel_list]
        F_dev = [F_dev_dict[relation] for relation in rel_list]
        F_test = [F_test_dict[relation] for relation in rel_list]

    logger.info("Done.")

    for i, cand in enumerate(cands):
        logger.info(f"{cand} Train shape: {F_train[i].shape}")
        logger.info(f"{cand} Test shape: {F_test[i].shape}")
        logger.info(f"{cand} Dev shape: {F_dev[i].shape}")

    logger.info("Labeling training data...")

    # Labeling
    start = timer()
    lfs = []
    if stg_temp_min:
        lfs.append(stg_temp_min_lfs)

    if stg_temp_max:
        lfs.append(stg_temp_max_lfs)

    if polarity:
        lfs.append(polarity_lfs)

    if ce_v_max:
        lfs.append(ce_v_max_lfs)

    # Using parallelism = 1 for deterministic behavior.
    labeler = Labeler(session, cands, parallelism=1)

    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True)
        logger.info("Done...")

        # Uncomment if debugging LFs
        #  load_transistor_labels(session, cands, ["ce_v_max"])
        #  labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        #  labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    elif re_label:
        logger.info("Updating LFs...")
        labeler.update(split=0, lfs=lfs)
        logger.info("Done...")

        # Uncomment if debugging LFs
        #  labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        #  labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    logger.info("Getting label matrices...")

    L_train = labeler.get_label_matrices(train_cands)

    # Uncomment if debugging LFs
    #  L_dev = labeler.get_label_matrices(dev_cands)
    #  L_dev_gold = labeler.get_gold_labels(dev_cands, annotator="gold")
    #
    #  L_test = labeler.get_label_matrices(test_cands)
    #  L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")

    logger.info("Done.")

    # NOTE(review): with re_label=True (and first_time=False) the marginals
    # are still loaded from the pickle rather than recomputed from the
    # updated L_train — confirm this is intended.
    if first_time:
        marginals_dict = {}
        for idx, relation in enumerate(rel_list):
            marginals_dict[relation] = generative_model(L_train[idx])

        with open(os.path.join(dirname, "marginals_dict.pkl"), "wb") as f:
            pickle.dump(marginals_dict, f)
    else:
        with open(os.path.join(dirname, "marginals_dict.pkl"), "rb") as f:
            marginals_dict = pickle.load(f)

    marginals = [marginals_dict[relation] for relation in rel_list]

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    start = timer()

    word_counter = collect_word_counter(train_cands)

    # Training config
    config = {
        "meta_config": {
            "verbose": True,
            "seed": 17
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.0
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    "model/all/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }

    emmental.init(log_dir=Meta.log_path, config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens: an opening and a closing marker per argument.
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)
    train_idxs = []
    train_dataloader = []
    for idx, relation in enumerate(rel_list):
        # Keep only the rows whose marginals are not (numerically) uniform,
        # i.e. max and min probability differ by more than 1e-6.
        diffs = marginals[idx].max(axis=1) - marginals[idx].min(axis=1)
        train_idxs.append(np.where(diffs > 1e-6)[0])

        train_dataloader.append(
            EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(
                    relation,
                    train_cands[idx],
                    F_train[idx],
                    emb_layer.word2id,
                    marginals[idx],
                    train_idxs[idx],
                ),
                split="train",
                batch_size=100,
                shuffle=True,
            ))

    num_feature_keys = len(featurizer.get_keys())

    model = EmmentalModel(name="transistor_tasks")

    # List relation names, arities, list of classes
    tasks = create_task(
        rel_list,
        [2] * len(rel_list),
        num_feature_keys,
        [2] * len(rel_list),
        emb_layer,
        model="LogisticRegression",
    )

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()

    # If given a list of multi, will train on multiple
    emmental_learner.learn(model, train_dataloader)

    # One test dataloader per relation
    for idx, relation in enumerate(rel_list):
        test_dataloader = EmmentalDataLoader(
            task_to_label_dict={relation: "labels"},
            dataset=FonduerDataset(relation, test_cands[idx], F_test[idx],
                                   emb_layer.word2id, 2),
            split="test",
            batch_size=100,
            shuffle=False,
        )

        test_preds = model.predict(test_dataloader, return_preds=True)

        best_result, best_b = scoring(
            relation,
            test_preds,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

        # Dump CSV files for CE_V_MAX and POLARITY for digi-key analysis.
        # The file names are "<relation>_test_probs.csv" and
        # "<relation>_dev_probs.csv".
        if relation in ("ce_v_max", "polarity"):
            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=100,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            dump_candidates(test_cands[idx], Y_prob,
                            f"{relation}_test_probs.csv")
            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            dump_candidates(dev_cands[idx], Y_prob,
                            f"{relation}_dev_probs.csv")

    end = timer()
    logger.warning(f"Classification Time (min): {((end - start) / 60.0):.1f}")
    def run_labeling_functions(cands):
        """Apply the station/price labeling functions and fit a LabelModel.

        The labeling functions are defined as closures so they can share the
        label constants and call one another through ``base``.

        Args:
            cands: Indexable collection of candidate splits. Index 0 is used
                as the training split and index 1 as the development split.

        Returns:
            A ``(gen_model, train_marginals_lfs)`` tuple: the fitted
            ``LabelModel`` and the output of its ``predict_proba`` on the
            training label matrix.
        """
        ABSTAIN = -1
        FALSE = 0
        TRUE = 1
        # Only the train and dev splits are labeled in this function.
        train_cands = cands[0]
        dev_cands = cands[1]

        @labeling_function()
        def LF_other_station_table(c):
            station_span = c.station.context.get_span().lower()
            neighbour_cells = get_neighbor_cell_ngrams_own(c.price, dist=100, directions=True, n_max=4, absolute=True)
            # Keep (ngram, direction) entries tagged 'DOWN' whose ngram is a known station name.
            up_cells = [x for x in neighbour_cells if len(x) > 1 and x[1] == 'DOWN' and x[0] in stations_list]
            # No station name in upper cells
            if not up_cells:
                return ABSTAIN
            # Check if the next upper aligned station-span corresponds to the candidate span (or equivalents)
            closest_header = up_cells[-1]
            return TRUE if closest_header[0] in stations_mapping_dict[station_span] else FALSE

        @labeling_function()
        def LF_station_non_meta_tag(c):
            html_tags = get_ancestor_tag_names(c.station)
            return FALSE if ('head' in html_tags and 'title' in html_tags) else ABSTAIN

        # Basic constraint for the price LFs to be true -> no wrong station (increase accuracy)
        def base(c):
            # Every negative LF must not vote FALSE. All four checks use the
            # same explicit comparison instead of relying on ABSTAIN (-1)
            # happening to be truthy.
            return (
                LF_station_non_meta_tag(c) != FALSE
                and LF_other_station_table(c) != FALSE
                and LF_off_peak_head(c) != FALSE
                and LF_purchases(c) != FALSE
            )

        # 2.) Create labeling functions
        @labeling_function()
        def LF_on_peak_head(c):
            return TRUE if 'on peak' in get_aligned_ngrams(c.price, n_min=2, n_max=2) and base(c) else ABSTAIN

        @labeling_function()
        def LF_off_peak_head(c):
            return FALSE if 'off peak' in get_aligned_ngrams(c.price, n_min=2, n_max=2) else ABSTAIN

        @labeling_function()
        def LF_price_range(c):
            price = float(c.price.context.get_span())
            # Never abstains: anything outside (0, 1000) or failing base() is FALSE.
            return TRUE if 0 < price < 1000 and base(c) else FALSE

        @labeling_function()
        def LF_price_head(c):
            return TRUE if 'price' in get_aligned_ngrams(c.price) and base(c) else ABSTAIN

        @labeling_function()
        def LF_firm_head(c):
            return TRUE if 'firm' in get_aligned_ngrams(c.price) and base(c) else ABSTAIN

        @labeling_function()
        def LF_dollar_to_left(c):
            return TRUE if '$' in get_left_ngrams(c.price, window=2) and base(c) else ABSTAIN

        @labeling_function()
        def LF_purchases(c):
            return FALSE if 'purchases' in get_aligned_ngrams(c.price, n_min=1) else ABSTAIN

        station_price_lfs = [
            LF_other_station_table,
            LF_station_non_meta_tag,

            # indicator
            LF_price_range,

            # negative indicators
            LF_off_peak_head,
            LF_purchases,

            # positive indicators
            LF_on_peak_head,
            LF_price_head,
            LF_firm_head,
            LF_dollar_to_left,
        ]

        # 3.) Apply the LFs on the training set
        labeler = Labeler(session, [StationPrice])
        labeler.apply(split=0, lfs=[station_price_lfs], train=True, clear=True, parallelism=PARALLEL)
        L_train = labeler.get_label_matrices(train_cands)

        # Check that LFs are all applied (avoid crash)
        num_applied = L_train[0].shape[1]
        has_non_applied = num_applied != len(station_price_lfs)
        print(f"Labeling functions on train_cands not ABSTAIN: {num_applied} (/{len(station_price_lfs)})")

        if has_non_applied:
            # Drop LFs that produced no column so the LF list matches L_train.
            applied_names = get_applied_lfs(session)
            non_applied_lfs = [lf.name for lf in station_price_lfs if lf.name not in applied_names]
            print(f"Labling functions {non_applied_lfs} are not applied.")
            station_price_lfs = [lf for lf in station_price_lfs if lf.name in applied_names]

        # 4.) Evaluate their accuracy
        L_gold_train = labeler.get_gold_labels(train_cands, annotator='gold')
        # Sort LFs for LFAnalysis because LFAnalysis does not sort LFs,
        # while columns of L_train are sorted alphabetically already.
        sorted_lfs = sorted(station_price_lfs, key=lambda lf: lf.name)
        LFAnalysis(L=L_train[0], lfs=sorted_lfs).lf_summary(Y=L_gold_train[0].reshape(-1))

        # 5.) Build generative model
        gen_model = LabelModel(cardinality=2)
        gen_model.fit(L_train[0], n_epochs=500, log_freq=100)

        train_marginals_lfs = gen_model.predict_proba(L_train[0])

        # Apply on dev-set
        labeler.apply(split=1, lfs=[station_price_lfs], clear=True, parallelism=PARALLEL)
        L_dev = labeler.get_label_matrices(dev_cands)

        L_gold_dev = labeler.get_gold_labels(dev_cands, annotator='gold')
        LFAnalysis(L=L_dev[0], lfs=sorted_lfs).lf_summary(Y=L_gold_dev[0].reshape(-1))
        return (gen_model, train_marginals_lfs)
Exemple #12
0
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)

    F = [F_train, F_dev, F_test]

    # 7.) Load gold data
    print("\n#7 Load Gold Data")
    gold = get_gold_func(gold_file,
                         attribute=ATTRIBUTE,
                         stations_mapping_dict=stations_mapping_dict)
    docs = corpus_parser.get_documents()
    labeler = Labeler(session, [StationPrice])
    labeler.apply(docs=docs,
                  lfs=[[gold]],
                  table=GoldLabel,
                  train=True,
                  parallelism=PARALLEL)

    # 8.) Rule-based evaluation (Generative Model)
    (train_model, eval_model,
     run_labeling_functions) = get_model_methods(session, ATTRIBUTE, gold,
                                                 gold_file, all_docs,
                                                 StationPrice, PARALLEL)
    if ("rule-based" in cls_methods):
        print("\n#8 Rule-based classification")
        (gen_model, train_marginals_lfs) = run_labeling_functions(cands)
        eval_LFs(train_marginals_lfs, train_cands, gold)

    # 9.) Supervised classification (Logistic Regression)
    train_marginals_gold = np.array([[0, 1] if gold(x) else [1, 0]
Exemple #13
0
def main(
    conn_string,
    stg_temp_min=False,
    stg_temp_max=False,
    polarity=False,
    ce_v_max=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    gpu=None,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Run the pipeline: parse, extract, featurize, label, classify, score.

    Args:
        conn_string: Database connection string passed to ``Meta.init``.
        stg_temp_min: Enable the ``stg_temp_min`` relation.
        stg_temp_max: Enable the ``stg_temp_max`` relation.
        polarity: Enable the ``polarity`` relation.
        ce_v_max: Enable the ``ce_v_max`` relation.
        max_docs: Forwarded to ``parse_dataset``.
        parse: Forwarded to ``parse_dataset`` as its ``first_time`` argument.
        first_time: When true, run mention extraction, candidate extraction,
            featurization and LF application, and write the feature matrices
            to ``F_*.pkl`` next to this file. When false, the feature
            matrices are loaded from those pickle files instead.
        re_label: When ``first_time`` is false, update the labels for split 0
            via ``labeler.update``.
        gpu: If truthy, exported as ``CUDA_VISIBLE_DEVICES``; also passed to
            ``discriminative_model``.
        parallel: Parallelism passed to the parsing/extraction/labeling steps.
        log_dir: Log directory relative to this file; defaults to ``"logs"``.
        verbose: Log at INFO level when true, WARNING otherwise.
    """
    # Setup initial configuration
    if gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    # The order of rel_list must match the order in which candidate classes
    # and LF lists are appended below, since they are indexed together.
    rel_list = []
    if stg_temp_min:
        rel_list.append("stg_temp_min")

    if stg_temp_max:
        rel_list.append("stg_temp_max")

    if polarity:
        rel_list.append("polarity")

    if ce_v_max:
        rel_list.append("ce_v_max")

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info("Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(session,
                                                          dirname,
                                                          first_time=parse,
                                                          parallel=parallel,
                                                          max_docs=max_docs)
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    Part = mention_subclass("Part")
    part_matcher = get_matcher("part")
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)

    mentions.append(Part)
    ngrams.append(part_ngrams)
    matchers.append(part_matcher)

    if stg_temp_min:
        StgTempMin = mention_subclass("StgTempMin")
        stg_temp_min_matcher = get_matcher("stg_temp_min")
        stg_temp_min_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMin)
        ngrams.append(stg_temp_min_ngrams)
        matchers.append(stg_temp_min_matcher)

    if stg_temp_max:
        StgTempMax = mention_subclass("StgTempMax")
        stg_temp_max_matcher = get_matcher("stg_temp_max")
        stg_temp_max_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMax)
        ngrams.append(stg_temp_max_ngrams)
        matchers.append(stg_temp_max_matcher)

    if polarity:
        Polarity = mention_subclass("Polarity")
        polarity_matcher = get_matcher("polarity")
        polarity_ngrams = MentionNgrams(n_max=1)

        mentions.append(Polarity)
        ngrams.append(polarity_ngrams)
        matchers.append(polarity_matcher)

    if ce_v_max:
        CeVMax = mention_subclass("CeVMax")
        ce_v_max_matcher = get_matcher("ce_v_max")
        ce_v_max_ngrams = MentionNgramsVolt(n_max=1)

        mentions.append(CeVMax)
        ngrams.append(ce_v_max_ngrams)
        matchers.append(ce_v_max_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")
    logger.info(f"Total Part: {session.query(Part).count()}")
    if stg_temp_min:
        logger.info(f"Total StgTempMin: {session.query(StgTempMin).count()}")
    if stg_temp_max:
        logger.info(f"Total StgTempMax: {session.query(StgTempMax).count()}")
    if polarity:
        logger.info(f"Total Polarity: {session.query(Polarity).count()}")
    if ce_v_max:
        logger.info(f"Total CeVMax: {session.query(CeVMax).count()}")

    # Candidate Extraction
    cands = []
    throttlers = []
    if stg_temp_min:
        PartStgTempMin = candidate_subclass("PartStgTempMin",
                                            [Part, StgTempMin])
        stg_temp_min_throttler = stg_temp_filter

        cands.append(PartStgTempMin)
        throttlers.append(stg_temp_min_throttler)

    if stg_temp_max:
        PartStgTempMax = candidate_subclass("PartStgTempMax",
                                            [Part, StgTempMax])
        stg_temp_max_throttler = stg_temp_filter

        cands.append(PartStgTempMax)
        throttlers.append(stg_temp_max_throttler)

    if polarity:
        PartPolarity = candidate_subclass("PartPolarity", [Part, Polarity])
        polarity_throttler = polarity_filter

        cands.append(PartPolarity)
        throttlers.append(polarity_throttler)

    if ce_v_max:
        PartCeVMax = candidate_subclass("PartCeVMax", [Part, CeVMax])
        ce_v_max_throttler = ce_v_max_filter

        cands.append(PartCeVMax)
        throttlers.append(ce_v_max_throttler)

    candidate_extractor = CandidateExtractor(session,
                                             cands,
                                             throttlers=throttlers)

    if first_time:
        # The split index is the position in this list: 0=train, 1=dev, 2=test.
        # A distinct loop name avoids shadowing the parsed `docs` list above.
        for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(split_docs, split=i, parallelism=parallel)
            num_cands = session.query(Candidate).filter(
                Candidate.split == i).count()
            logger.info(f"Candidates in split={i}: {num_cands}")

    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)

    end = timer()
    logger.warning(
        f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"Total train candidate: {sum(len(_) for _ in train_cands)}")
    logger.info(f"Total dev candidate: {sum(len(_) for _ in dev_cands)}")
    logger.info(f"Total test candidate: {sum(len(_) for _ in test_cands)}")

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    pickle_file = os.path.join(dirname, "data/parts_by_doc_new.pkl")
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    # Check total recall
    for i, name in enumerate(rel_list):
        logger.info(name)
        result = entity_level_scores(
            candidates_to_entities(dev_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=dev_docs,
        )
        logger.info(f"{name} Total Dev Recall: {result.rec:.3f}")
        result = entity_level_scores(
            candidates_to_entities(test_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=test_docs,
        )
        logger.info(f"{name} Total Test Recall: {result.rec:.3f}")

    # Featurization
    start = timer()
    cands = []
    if stg_temp_min:
        cands.append(PartStgTempMin)

    if stg_temp_max:
        cands.append(PartStgTempMax)

    if polarity:
        cands.append(PartPolarity)

    if ce_v_max:
        cands.append(PartCeVMax)

    featurizer = Featurizer(session, cands)
    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True, parallelism=parallel)
        featurizer.apply(split=1, parallelism=parallel)
        featurizer.apply(split=2, parallelism=parallel)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        # Cache the matrices so later runs can skip featurization. Context
        # managers make sure each file is flushed and closed.
        with open(os.path.join(dirname, "F_train.pkl"), "wb") as f:
            pickle.dump(F_train, f)
        with open(os.path.join(dirname, "F_dev.pkl"), "wb") as f:
            pickle.dump(F_dev, f)
        with open(os.path.join(dirname, "F_test.pkl"), "wb") as f:
            pickle.dump(F_test, f)
    else:
        with open(os.path.join(dirname, "F_train.pkl"), "rb") as f:
            F_train = pickle.load(f)
        with open(os.path.join(dirname, "F_dev.pkl"), "rb") as f:
            F_dev = pickle.load(f)
        with open(os.path.join(dirname, "F_test.pkl"), "rb") as f:
            F_test = pickle.load(f)
    logger.info("Done.")

    for i, cand in enumerate(cands):
        logger.info(f"{cand} Train shape: {F_train[i].shape}")
        logger.info(f"{cand} Test shape: {F_test[i].shape}")
        logger.info(f"{cand} Dev shape: {F_dev[i].shape}")

    logger.info("Labeling training data...")

    # Labeling
    start = timer()
    lfs = []
    if stg_temp_min:
        lfs.append(stg_temp_min_lfs)

    if stg_temp_max:
        lfs.append(stg_temp_max_lfs)

    if polarity:
        lfs.append(polarity_lfs)

    if ce_v_max:
        lfs.append(ce_v_max_lfs)

    labeler = Labeler(session, cands)

    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
        logger.info("Done...")

        # Uncomment if debugging LFs
        #  load_transistor_labels(session, cands, ["ce_v_max"])
        #  labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        #  labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    elif re_label:
        logger.info("Updating LFs...")
        labeler.update(split=0, lfs=lfs, parallelism=parallel)
        logger.info("Done...")

        # Uncomment if debugging LFs
        #  labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        #  labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    logger.info("Getting label matrices...")

    L_train = labeler.get_label_matrices(train_cands)

    # Uncomment if debugging LFs
    #  L_dev = labeler.get_label_matrices(dev_cands)
    #  L_dev_gold = labeler.get_gold_labels(dev_cands, annotator="gold")
    #
    #  L_test = labeler.get_label_matrices(test_cands)
    #  L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")

    logger.info("Done.")

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    # Classification: for each enabled relation, derive marginals from the
    # label matrix, train a discriminative model, and score on the test split.
    start = timer()
    if stg_temp_min:
        relation = "stg_temp_min"
        idx = rel_list.index(relation)
        marginals_stg_temp_min = generative_model(L_train[idx])
        disc_model_stg_temp_min = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals_stg_temp_min,
            n_epochs=100,
            gpu=gpu,
        )
        best_result, best_b = scoring(
            relation,
            disc_model_stg_temp_min,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

    if stg_temp_max:
        relation = "stg_temp_max"
        idx = rel_list.index(relation)
        marginals_stg_temp_max = generative_model(L_train[idx])
        disc_model_stg_temp_max = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals_stg_temp_max,
            n_epochs=100,
            gpu=gpu,
        )
        best_result, best_b = scoring(
            relation,
            disc_model_stg_temp_max,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

    if polarity:
        relation = "polarity"
        idx = rel_list.index(relation)
        marginals_polarity = generative_model(L_train[idx])
        disc_model_polarity = discriminative_model(train_cands[idx],
                                                   F_train[idx],
                                                   marginals_polarity,
                                                   n_epochs=100,
                                                   gpu=gpu)
        best_result, best_b = scoring(
            relation,
            disc_model_polarity,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

    if ce_v_max:
        relation = "ce_v_max"
        idx = rel_list.index(relation)

        # Can be uncommented for use in debugging labeling functions
        #  logger.info("Updating labeling function summary...")
        #  keys = labeler.get_keys()
        #  logger.info("Summary for train set labeling functions:")
        #  df = analysis.lf_summary(L_train[idx], lf_names=keys)
        #  logger.info(f"\n{df.to_string()}")
        #
        #  logger.info("Summary for dev set labeling functions:")
        #  df = analysis.lf_summary(
        #      L_dev[idx],
        #      lf_names=keys,
        #      Y=L_dev_gold[idx].todense().reshape(-1).tolist()[0],
        #  )
        #  logger.info(f"\n{df.to_string()}")
        #
        #  logger.info("Summary for test set labeling functions:")
        #  df = analysis.lf_summary(
        #      L_test[idx],
        #      lf_names=keys,
        #      Y=L_test_gold[idx].todense().reshape(-1).tolist()[0],
        #  )
        #  logger.info(f"\n{df.to_string()}")

        marginals_ce_v_max = generative_model(L_train[idx])
        disc_model_ce_v_max = discriminative_model(train_cands[idx],
                                                   F_train[idx],
                                                   marginals_ce_v_max,
                                                   n_epochs=100,
                                                   gpu=gpu)

        # Can be uncommented to view score on development set
        #  best_result, best_b = scoring(
        #      relation,
        #      disc_model_ce_v_max,
        #      dev_cands[idx],
        #      dev_docs,
        #      F_dev[idx],
        #      parts_by_doc,
        #      num=100,
        #  )

        best_result, best_b = scoring(
            relation,
            disc_model_ce_v_max,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

    end = timer()
    logger.warning(f"Classification Time (min): {((end - start) / 60.0):.1f}")

    # Dump CSV files for CE_V_MAX for digi-key analysis
    if ce_v_max:
        relation = "ce_v_max"
        idx = rel_list.index(relation)
        Y_prob = disc_model_ce_v_max.marginals((test_cands[idx], F_test[idx]))
        dump_candidates(test_cands[idx], Y_prob, "ce_v_max_test_probs.csv")
        Y_prob = disc_model_ce_v_max.marginals((dev_cands[idx], F_dev[idx]))
        dump_candidates(dev_cands[idx], Y_prob, "ce_v_max_dev_probs.csv")

    # Dump CSV files for POLARITY for digi-key analysis
    if polarity:
        relation = "polarity"
        idx = rel_list.index(relation)
        Y_prob = disc_model_polarity.marginals((test_cands[idx], F_test[idx]))
        dump_candidates(test_cands[idx], Y_prob, "polarity_test_probs.csv")
        Y_prob = disc_model_polarity.marginals((dev_cands[idx], F_dev[idx]))
        dump_candidates(dev_cands[idx], Y_prob, "polarity_dev_probs.csv")
Exemple #14
0
# Report how many candidates the first entry of train_cands holds.
print(
    f"Number of Candidates: {len(train_cands[0])}"
)

from fonduer.features import Featurizer

# Featurize split 0 (train=True) and fetch the feature matrices for the
# training candidates.
featurizer = Featurizer(session, candidate_classes)
featurizer.apply(split=0, train=True, parallelism=PARALLEL)
F_train = featurizer.get_feature_matrices(train_cands)

from wiki_table_utils import gold
from fonduer.supervision import Labeler
from fonduer.supervision.models import GoldLabel

# Store gold labels by applying the `gold` function as the only LF, writing
# into the GoldLabel table for the training documents.
labeler = Labeler(session, candidate_classes)
labeler.apply(docs=train_docs, lfs=[[gold]], table=GoldLabel, train=True)

from fonduer_lfs import president_name_pob_lfs

# Apply the labeling functions to split 0 and fetch the label matrices.
labeler.apply(split=0, lfs=[president_name_pob_lfs], train=True, parallelism=PARALLEL)
L_train = labeler.get_label_matrices(train_cands)

# Gold labels stored under the "gold" annotator, for the same candidates.
L_gold_train = labeler.get_gold_labels(train_cands, annotator="gold")

from snorkel.labeling.model import LabelModel

# Fit a label model on the first label matrix for 500 epochs, then compute
# its predicted probabilities for those same rows.
label_model = LabelModel(verbose=False)
label_model.fit(L_train[0], n_epochs=500)

train_marginals = label_model.predict_proba(L_train[0])