コード例 #1
0
ファイル: test_e2e.py プロジェクト: robingong/fonduer
def _f1_from_counts(tp_len, fp_len, fn_len):
    """Return ``(precision, recall, f1)`` computed from TP/FP/FN counts.

    A metric whose denominator is not positive is reported as ``nan`` instead
    of raising ``ZeroDivisionError``; a ``nan`` precision or recall also makes
    the F1 ``nan``.
    """
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")
    return prec, rec, f1


def _entity_level_f1_score(
    disc_model, test_cands, F_test, gold_file, test_docs, parts_by_doc, b
):
    """Score ``disc_model`` on the test split and return its entity-level F1.

    Only the first relation (index 0) of ``test_cands`` / ``F_test`` is used.
    Candidates whose prediction is greater than zero are passed to
    ``entity_level_f1`` together with the gold file; precision, recall and F1
    are logged at INFO level.

    Args:
        disc_model: Trained model exposing ``predictions(X, b=...)``.
        test_cands: Candidate lists for the test split, one list per relation.
        F_test: Feature matrices for the test split, one per relation.
        gold_file: Path of the gold CSV handed to ``entity_level_f1``.
        test_docs: Documents of the test split.
        parts_by_doc: Mapping handed to ``entity_level_f1`` unchanged.
        b: Value forwarded to ``disc_model.predictions`` as ``b``.

    Returns:
        The F1 score, or ``nan`` when it cannot be computed.
    """
    test_score = disc_model.predictions((test_cands[0], F_test[0]), b=b)
    true_pred = [
        test_cands[0][idx] for idx in np.nditer(np.where(test_score > 0))
    ]

    TP, FP, FN = entity_level_f1(
        true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc
    )
    prec, rec, f1 = _f1_from_counts(len(TP), len(FP), len(FN))

    logger.info("prec: {}".format(prec))
    logger.info("rec: {}".format(rec))
    logger.info("f1: {}".format(f1))
    return f1


def test_e2e(caplog):
    """Run an end-to-end test on documents of the hardware domain.

    The test parses the HTML/PDF fixtures, extracts mentions and candidates,
    featurizes and labels them, trains a label model plus several
    discriminative models, and asserts on row counts, matrix shapes and
    entity-level F1 at each stage.
    """
    caplog.set_level(logging.INFO)
    # SpaCy on mac has issue on parallel parsing
    # NOTE(review): ``os.name`` is "posix" on Linux as well as macOS, so every
    # POSIX host (including a Linux CI worker) takes the single-core branch.
    if os.name == "posix":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 12

    session = Meta.init("postgres://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info("Docs: {}".format(num_docs))
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info("Sentences: {}".format(num_sentences))

    # Divide into test and train
    docs = corpus_parser.get_documents()
    ld = len(docs)
    assert ld == len(corpus_parser.get_last_documents())

    # Per-document structure counts, in the order the parser returns documents.
    assert [len(doc.sentences) for doc in docs] == [
        799, 663, 784, 661, 513, 700, 528, 161, 228, 511, 331, 528,
    ]

    # Check table numbers
    assert [len(doc.tables) for doc in docs] == [
        9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9,
    ]

    # Check figure numbers
    assert [len(doc.figures) for doc in docs] == [
        32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27,
    ]

    # Check caption numbers
    assert [len(doc.captions) for doc in docs] == [0] * max_docs

    # Split by document name: first 50% train, next 25% dev, remainder test.
    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    for i, doc in enumerate(sorted(docs, key=lambda d: d.name)):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 147
    assert len(mention_extractor.get_mentions()) == 2
    assert len(mention_extractor.get_mentions()[0]) == 299
    assert (
        len(
            mention_extractor.get_mentions(
                docs=[session.query(Document).filter(Document.name == "112823").first()]
            )[0]
        )
        == 70
    )

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )

    # Use a dedicated loop name so the parsed ``docs`` list is not shadowed.
    for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3684
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 72
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 448

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 3684
    assert (
        len(
            candidate_extractor.get_candidates(
                docs=[session.query(Document).filter(Document.name == "112823").first()]
            )[0]
        )
        == 1496
    )

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 72
    assert session.query(FeatureKey).count() == 716

    # Test Dropping FeatureKey
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NFP NN NFP]"])
    assert session.query(FeatureKey).count() == 715
    session.query(Feature).delete()

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 3684
    assert session.query(FeatureKey).count() == 3748
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3684, 3748)
    assert len(featurizer.get_keys()) == 3748

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 3756
    assert session.query(FeatureKey).count() == 3748
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (72, 3748)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 4204
    assert session.query(FeatureKey).count() == 3748
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (448, 3748)

    gold_file = "tests/data/hardware_tutorial_gold.csv"
    load_hardware_labels(session, PartTemp, gold_file, ATTRIBUTE, annotator_name="gold")
    assert session.query(GoldLabel).count() == 4204

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    # A flat list of LFs (not one list per relation) must be rejected.
    with pytest.raises(ValueError):
        labeler.apply(split=0, lfs=stg_temp_lfs, train=True, parallelism=PARALLEL)

    labeler.apply(split=0, lfs=[stg_temp_lfs], train=True, parallelism=PARALLEL)
    assert session.query(Label).count() == 3684
    assert session.query(LabelKey).count() == 6
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3684, 6)
    assert len(labeler.get_keys()) == 6

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3684, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3684, 1)

    gen_model = LabelModel(k=2)
    gen_model.train(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])[:, 1]

    disc_model = LogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=20, lr=0.001
    )

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    f1 = _entity_level_f1_score(
        disc_model, test_cands, F_test, gold_file, test_docs, parts_by_doc, b=0.6
    )

    # With only the first six LFs the score is expected to be mediocre.
    assert 0.3 < f1 < 0.7

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0, lfs=[stg_temp_lfs_2], parallelism=PARALLEL)
    assert session.query(Label).count() == 3684
    assert session.query(LabelKey).count() == 13
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3684, 13)

    gen_model = LabelModel(k=2)
    gen_model.train(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])[:, 1]

    disc_model = LogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=20, lr=0.001
    )

    f1 = _entity_level_f1_score(
        disc_model, test_cands, F_test, gold_file, test_docs, parts_by_doc, b=0.6
    )
    assert f1 > 0.7

    # Testing LSTM
    disc_model = LSTM()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=5, lr=0.001
    )

    f1 = _entity_level_f1_score(
        disc_model, test_cands, F_test, gold_file, test_docs, parts_by_doc, b=0.6
    )
    assert f1 > 0.7

    # Testing Sparse Logistic Regression
    disc_model = SparseLogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=20, lr=0.001
    )

    # This model is evaluated with b=0.9 rather than the 0.6 used above.
    f1 = _entity_level_f1_score(
        disc_model, test_cands, F_test, gold_file, test_docs, parts_by_doc, b=0.9
    )
    assert f1 > 0.7
コード例 #2
0
ファイル: test_e2e.py プロジェクト: atulgupta9/fonduer
def _precision_recall_f1(tp_len, fp_len, fn_len):
    """Return ``(precision, recall, f1)`` computed from TP/FP/FN counts.

    A metric whose denominator is not positive is reported as ``nan`` instead
    of raising ``ZeroDivisionError``; a ``nan`` precision or recall also makes
    the F1 ``nan``.
    """
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")
    return prec, rec, f1


def _predicted_entity_f1(
    disc_model, test_cands, F_test, gold_file, test_docs, parts_by_doc, b=0.6
):
    """Predict on the test split with ``disc_model`` and return entity-level F1.

    Only the first relation (index 0) of ``test_cands`` / ``F_test`` is used.
    Candidates predicted as ``TRUE`` are passed to ``entity_level_f1`` together
    with the gold file; precision, recall and F1 are logged at INFO level.

    Args:
        disc_model: Trained model exposing ``predict(X, b=..., pos_label=...)``.
        test_cands: Candidate lists for the test split, one list per relation.
        F_test: Feature matrices for the test split, one per relation.
        gold_file: Path of the gold CSV handed to ``entity_level_f1``.
        test_docs: Documents of the test split.
        parts_by_doc: Mapping handed to ``entity_level_f1`` unchanged.
        b: Value forwarded to ``disc_model.predict`` as ``b``.

    Returns:
        The F1 score, or ``nan`` when it cannot be computed.
    """
    test_score = disc_model.predict(
        (test_cands[0], F_test[0]), b=b, pos_label=TRUE
    )
    true_pred = [
        test_cands[0][idx] for idx in np.nditer(np.where(test_score == TRUE))
    ]

    TP, FP, FN = entity_level_f1(
        true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc
    )
    prec, rec, f1 = _precision_recall_f1(len(TP), len(FP), len(FN))

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")
    return f1


def test_e2e():
    """Run an end-to-end test on documents of the hardware domain.

    The test parses the HTML/PDF fixtures, extracts Part/Temp/Volt mentions
    and PartTemp/PartVolt candidates, featurizes and labels them, trains a
    label model plus several discriminative models, and asserts on row counts,
    matrix shapes and F1 scores at each stage.
    """
    PARALLEL = 4

    max_docs = 12

    fonduer.init_logging(
        log_dir="log_folder",
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = fonduer.Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Per-document structure counts, in sorted document order.
    assert [len(doc.sentences) for doc in docs] == [
        799, 663, 784, 661, 513, 700, 528, 161, 228, 511, 331, 528,
    ]

    # Check table numbers
    assert [len(doc.tables) for doc in docs] == [
        9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9,
    ]

    # Check figure numbers
    assert [len(doc.figures) for doc in docs] == [
        32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27,
    ]

    # Check caption numbers
    assert [len(doc.captions) for doc in docs] == [0] * max_docs

    # Split by document name: first 50% train, next 25% dev, remainder test.
    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    for i, doc in enumerate(sorted(docs, key=lambda d: d.name)):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 138
    assert session.query(Volt).count() == 140
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299
    assert (
        len(
            mention_extractor.get_mentions(
                docs=[session.query(Document).filter(Document.name == "112823").first()]
            )[0]
        )
        == 70
    )

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    # Use a dedicated loop name so the parsed ``docs`` list is not shadowed.
    for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3493
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 416
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3493
    assert (
        len(
            candidate_extractor.get_candidates(
                docs=[session.query(Document).filter(Document.name == "112823").first()]
            )[0]
        )
        == 1432
    )

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    assert session.query(FeatureKey).count() == 1260

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NNP NN IN]"])
    assert session.query(FeatureKey).count() == 1259

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(
        session.query(FeatureKey)
        .filter(FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]")
        .one()
        .candidate_classes
    ) == {"part_temp", "part_volt"}
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]"
    ).one().candidate_classes == ["part_temp"]
    assert session.query(FeatureKey).count() == 1259

    # Inserting the removed key
    featurizer.upsert_keys(
        ["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp, PartVolt]
    )
    assert set(
        session.query(FeatureKey)
        .filter(FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]")
        .one()
        .candidate_classes
    ) == {"part_temp", "part_volt"}
    assert session.query(FeatureKey).count() == 1259
    # Removing the key again
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1258
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6478
    assert session.query(FeatureKey).count() == 4538
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3493, 4538)
    assert F_train[1].shape == (2985, 4538)
    assert len(featurizer.get_keys()) == 4538

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6692
    assert session.query(FeatureKey).count() == 4538
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (61, 4538)
    assert F_dev[1].shape == (153, 4538)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8252
    assert session.query(FeatureKey).count() == 4538
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (416, 4538)
    assert F_test[1].shape == (1144, 4538)

    gold_file = "tests/data/hardware_tutorial_gold.csv"

    labeler = Labeler(session, [PartTemp, PartVolt])

    # Gold labels are produced by running the ``gold`` function as the single
    # labeling function for each relation and storing into ``GoldLabel``.
    labeler.apply(
        docs=last_docs,
        lfs=[[gold], [gold]],
        table=GoldLabel,
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(GoldLabel).count() == 8252

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]

    # A flat list of LFs (not one list per relation) must be rejected.
    with pytest.raises(ValueError):
        labeler.apply(split=0, lfs=stg_temp_lfs, train=True, parallelism=PARALLEL)

    labeler.apply(
        docs=train_docs,
        lfs=[stg_temp_lfs, ce_v_max_lfs],
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    # Test Dropping LabelerKey
    labeler.drop_keys(["LF_storage_row"])
    assert len(labeler.get_keys()) == 8

    # Test Upserting LabelerKey
    labeler.upsert_keys(["LF_storage_row"])
    assert "LF_storage_row" in [label.name for label in labeler.get_keys()]

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3493, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3493, 1)

    gen_model = LabelModel(k=2)
    gen_model.train_model(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])

    disc_model = LogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]),
        train_marginals,
        X_dev=(train_cands[0], F_train[0]),
        Y_dev=np.array(L_train_gold[0].todense()).reshape(-1),
        b=0.6,
        pos_label=TRUE,
        n_epochs=5,
        lr=0.001,
    )

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    f1 = _predicted_entity_f1(
        disc_model, test_cands, F_test, gold_file, test_docs, parts_by_doc
    )

    # With only the first set of LFs the score is expected to be mediocre.
    assert 0.3 < f1 < 0.7

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0, lfs=[stg_temp_lfs_2, ce_v_max_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 16)

    gen_model = LabelModel(k=2)
    gen_model.train_model(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])

    disc_model = LogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=5, lr=0.001
    )

    f1 = _predicted_entity_f1(
        disc_model, test_cands, F_test, gold_file, test_docs, parts_by_doc
    )
    assert f1 > 0.7

    # Testing LSTM
    disc_model = LSTM()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=5, lr=0.001
    )

    f1 = _predicted_entity_f1(
        disc_model, test_cands, F_test, gold_file, test_docs, parts_by_doc
    )
    assert f1 > 0.7

    # Testing Sparse Logistic Regression
    disc_model = SparseLogisticRegression()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=5, lr=0.001
    )

    f1 = _predicted_entity_f1(
        disc_model, test_cands, F_test, gold_file, test_docs, parts_by_doc
    )
    assert f1 > 0.7

    # Testing Sparse LSTM
    disc_model = SparseLSTM()
    disc_model.train(
        (train_cands[0], F_train[0]), train_marginals, n_epochs=5, lr=0.001
    )

    f1 = _predicted_entity_f1(
        disc_model, test_cands, F_test, gold_file, test_docs, parts_by_doc
    )
    assert f1 > 0.7

    # Evaluate mention level scores (uses the SparseLSTM trained just above)
    L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")
    Y_test = np.array(L_test_gold[0].todense()).reshape(-1)

    scores = disc_model.score((test_cands[0], F_test[0]), Y_test, b=0.6, pos_label=TRUE)

    logger.info(scores)

    assert scores["f1"] > 0.6
コード例 #3
0
                        class_balance=new_balance,
                        n_epochs=500,
                        log_train_every=50)

# Score the trained label model on Ls[1]/Ys[1] (presumably the dev split --
# confirm against where Ls and Ys are built).
score = label_model.score((Ls[1], Ys[1]))

print('Trained Label Model Metrics:')
scores = label_model.score((Ls[1], Ys[1]),
                           metric=['accuracy', 'precision', 'recall', 'f1'])

# Majority-vote baseline scored on the same data for comparison.
mv = MajorityLabelVoter(seed=123)
print('Majority Label Voter Metrics:')
scores = mv.score((Ls[1], Ys[1]),
                  metric=['accuracy', 'precision', 'recall', 'f1'])

Y_train_ps = label_model.predict_proba(Ls[0])

Y_dev_p = label_model.predict(Ls[1])
"""
mv2 = MajorityClassVoter()
mv2.train_model(np.asarray(new_balance))
"""

#=np.asarray(new_balance))

#Y_baseline = mv2.predict(Ls[2])
# NOTE(review): the only visible assignment to Y_baseline is commented out
# above; unless it is defined earlier in the script, the dump below raises
# NameError.
# The context manager closes the pickle file even if dumping fails.
with open(
        "data_encompassing/ar/ar_baseline_{}{}".format(flag0, flag),
        "wb") as pickling_on2:
    pickle.dump(Y_baseline, pickling_on2)
print(Y_baseline)
コード例 #4
0
# In[18]:

import numpy as np  # `pd.np` was deprecated in pandas 1.0 and removed in 2.0

# 15 evenly spaced values from 0.01 to 5 (inclusive), rounded to two decimals.
regularization_grid = np.round(np.linspace(0.01, 5, num=15), 2)

# In[19]:

# Retrain the label model once per L2 strength and keep the predicted
# probabilities, keyed by the stringified strength.
label_model = LabelModel(k=2)
grid_results = {}

# Only the first seven columns of each label matrix are used.
lf_columns = slice(0, 7)
train_settings = dict(n_epochs=1000, print_every=200, seed=100, lr=0.01)

for param in tqdm_notebook(regularization_grid):
    label_model.train_model(correct_L[:, lf_columns], l2=param, **train_settings)
    # NOTE(review): training reads `correct_L` while prediction reads
    # `correct_L_train` -- confirm that the two different matrices are intended.
    probabilities = label_model.predict_proba(correct_L_train[:, lf_columns])
    grid_results[str(param)] = probabilities

# In[20]:

# Accuracy of each grid entry against the curated training labels: the first
# probability column is thresholded at 0.5 and missing labels count as 0.
acc_results = defaultdict(list)

for key in grid_results:
    true_labels = candidate_dfs['train']['curated_dsh'].fillna(0)
    predicted_labels = [
        1 if prob > 0.5 else 0 for prob in grid_results[key][:, 0]
    ]
    acc_results[key].append(accuracy_score(true_labels, predicted_labels))

acc_df = pd.DataFrame(acc_results)
acc_df.head(2)

# In[21]:
コード例 #5
0
# Gold labels for the training candidates, as written by the "gold" annotator.
L_gold_train = labeler.get_gold_labels(train_cands, annotator="gold")

from metal import analysis

# Summarise the labeling functions of the first relation against its gold
# labels; the gold matrix is densified and reduced to its first row as a list.
gold_first_row = L_gold_train[0].todense().reshape(-1).tolist()[0]
analysis.lf_summary(L_train[0], lf_names=labeler.get_keys(), Y=gold_first_row)

from metal.label_model import LabelModel

# Generative step: fit the label model and derive training marginals from it.
gen_model = LabelModel(k=2)
label_model_settings = dict(n_epochs=500, print_every=100)
gen_model.train_model(L_train[0], **label_model_settings)

train_marginals = gen_model.predict_proba(L_train[0])

from fonduer.learning import LogisticRegression

# Discriminative step: train on (candidates, features) of the first relation.
disc_model = LogisticRegression()
train_inputs = (train_cands[0], F_train[0])
disc_model.train(train_inputs, train_marginals, n_epochs=10, lr=0.001)

from my_fonduer_model import MyFonduerModel
model = MyFonduerModel()

import fonduer_model
fonduer_model.save_model(
    fonduer_model=model,
    model_path="fonduer_model",
    conn_string=conn_string,
    featurizer=featurizer,
コード例 #6
0
# Display names, paired positionally with the entries of `validation_data`.
model_labels = ["Knowledge Bases (KB)", "KB+Text Patterns", "All"]

# In[15]:

# Training options that stay fixed across the whole sweep; only `l2` varies.
fixed_train_options = dict(n_epochs=1000, verbose=False, lr=0.01)

# Maps model label -> DataFrame with one column of first-class probabilities
# per regularization strength.
model_grid_search = {}
for model_data, model_label in zip(validation_data, model_labels):

    label_model = LabelModel(k=2, seed=100)
    grid_results = {}
    for param in regularization_grid:
        # model_data[0] is used for fitting, model_data[1] for prediction.
        label_model.train_model(model_data[0], l2=param,
                                **fixed_train_options)
        predicted = label_model.predict_proba(model_data[1])
        grid_results[str(param)] = predicted[:, 0]

    model_grid_search[model_label] = pd.DataFrame.from_dict(grid_results)

# In[16]:

# Plot one ROC scatterplot per model against the curated dev labels and keep
# whatever `plot_curve` returns for each model.
model_grid_aucs = {}
for model, marginals_frame in model_grid_search.items():
    model_grid_aucs[model] = plot_curve(marginals_frame,
                                        candidate_dfs['dev'].curated_dsh,
                                        figsize=(16, 6),
                                        model_type='scatterplot',
                                        plot_title=model,
                                        metric="ROC",
                                        font_size=10)
コード例 #7
0
# `pandas.np` was deprecated in pandas 1.0 and removed in 2.0, so NumPy is
# imported directly instead of being reached through the pandas namespace.
import numpy as np

# Recode every negative entry of the test label matrix as 2.
L_test[L_test < 0] = 2

label_model = LabelModel(k=2, seed=100)

# In[ ]:

# 30 evenly spaced L2 strengths between 0.1 and 1, rounded to 3 decimals.
reg_param_grid = np.round(np.linspace(1e-1, 1, num=30), 3)

# grid_results[model][str(l2)] -> first probability column on the dev matrix.
grid_results = defaultdict(dict)
for model in tqdm_notebook(model_dict):
    # model_dict[model] selects the label-matrix columns used by this model.
    for reg_param in reg_param_grid:
        # NOTE(review): this calls `label_model.train(...)`, while other
        # LabelModel usages call `train_model(...)`; confirm which name the
        # installed MeTaL version exposes.
        label_model.train(L[:, model_dict[model]],
                          n_epochs=1000,
                          verbose=False,
                          lr=0.01,
                          l2=reg_param)
        grid_results[model][str(reg_param)] = label_model.predict_proba(
            L_dev[:, model_dict[model]])[:, 0]

# In[ ]:

# One ROC scatterplot per model, scored against the curated dev labels.
for model in grid_results:
    model_aucs = plot_roc_curve(pd.DataFrame.from_dict(grid_results[model]),
                                candidate_dfs['dev'].curated_dsh,
                                figsize=(16, 6),
                                model_type='scatterplot',
                                plot_title=model)

# In[ ]:

best_params = {
    "CtD_DB": 0.1,
    "CtD_TEXT": 0.4,
コード例 #8
0
class SnorkeMeTalCollator(Collator):
    """Collate several annotation sources into one using a MeTaL LabelModel.

    Every annotation source is treated as one labeling function, i.e. one
    column of a (words x sources) label matrix. A ``LabelModel`` is trained
    on that matrix and its per-word probabilities are converted back into
    tags.
    """

    def __init__(
        self,
        positive_label: str,
        class_cardinality: int = 2,
        num_epochs: int = 500,
        log_train_every: int = 50,
        seed: int = 123,
    ):
        """Store the training settings and build the underlying LabelModel.

        Args:
            positive_label: tag emitted by ``get_tag`` for index 1.
            class_cardinality: passed to ``LabelModel`` as ``k``.
            num_epochs: passed to ``train_model`` as ``n_epochs``.
            log_train_every: passed to ``train_model`` unchanged.
            seed: passed to the ``LabelModel`` constructor.
        """
        self.positive_label = positive_label
        self.class_cardinality = class_cardinality
        self.num_epochs = num_epochs
        self.log_train_every = log_train_every
        self.seed = seed
        self.label_model = LabelModel(k=self.class_cardinality, seed=seed)

    @classmethod
    def get_snorkel_index(cls, tag: str) -> int:
        """Encode a tag as a label-matrix value.

        Returns 2 for a positive tag, 1 for a negative tag and 0 for anything
        else (presumably the label model's "abstain" value -- confirm against
        the MeTaL conventions).
        """
        if is_positive(tag):
            return 2
        elif is_negative(tag):
            return 1
        else:
            return 0

    def get_tag(self, index: int) -> str:
        """Decode an argmax column index into a tag.

        Index 1 maps to ``self.positive_label``; every other index maps to
        ``NEGATIVE_LABEL``.
        """
        if index == 1:
            return self.positive_label
        else:
            return NEGATIVE_LABEL

    def get_index(self, prob: np.ndarray) -> int:
        """Return the argmax position of a length-2 probability vector."""
        assert prob.shape == (2, )
        return prob.argmax()

    def collate_np(
        self,
        annotations,
    ) -> Tuple[np.ndarray, List[str], Dict[int, Tuple[int, int]]]:
        """Stack all annotation sources into one dense label matrix.

        Args:
            annotations: one sequence of instances per annotation source.
                The sources are zipped together, so instance ``n`` of every
                source is expected to describe the same entry. Each instance
                is a mapping with ``'id'``, ``'input'`` and ``'output'``.

        Returns:
            A tuple ``(labels, words, id_to_labels)``:
            ``labels`` is a (total words x number of sources) array of values
            from ``get_snorkel_index``; ``words`` is the flat list of input
            words; ``id_to_labels`` maps each entry id to its
            ``(start, end)`` row range in ``labels`` / ``words``.

        Raises:
            ValueError: from ``np.concatenate`` when there are no instances.
        """
        output_arrs: List[np.ndarray] = []
        words_list: List[str] = []
        id_to_labels: Dict[int, Tuple[int, int]] = {}
        num_funcs = len(annotations)
        for ann_inst in tqdm(zip(*annotations)):
            # The first source supplies the words and the entry id.
            first_inst = ann_inst[0]
            words = first_inst['input']

            # output arr = (sentence x num_labels)
            output_arr = np.zeros((len(words), num_funcs))
            for func_i, inst in enumerate(ann_inst):
                for word_i, tag in enumerate(inst['output']):
                    output_arr[word_i, func_i] = (
                        SnorkeMeTalCollator.get_snorkel_index(tag))

            label_start = len(words_list)
            words_list.extend(words)
            output_arrs.append(output_arr)
            id_to_labels[first_inst['id']] = (label_start, len(words_list))
        output_res = np.concatenate(output_arrs, axis=0)
        return output_res, words_list, id_to_labels

    def train_label_model(
        self,
        collated_labels: np.ndarray,
        descriptions: Optional[List[str]],
        train_data_np: Optional[np.ndarray],
    ):
        """Train ``self.label_model`` on the collated label matrix.

        Args:
            collated_labels: dense label matrix, as built by ``collate_np``.
            descriptions: optional names of the labeling functions; logged
                together with their column index when given.
            train_data_np: optional labels forwarded to ``train_model`` as
                ``Y_dev``.
        """
        sparse_labels = sparse.csr_matrix(collated_labels)
        if descriptions is not None:
            descriptions = list(enumerate(descriptions))
            # `Logger.warn` is a deprecated alias of `Logger.warning`.
            logger.warning(f'labeling function order: {descriptions}')
        logger.warning(lf_summary(sparse_labels))
        self.label_model.train_model(
            sparse_labels,
            n_epochs=self.num_epochs,
            log_train_every=self.log_train_every,
            Y_dev=train_data_np,
        )

    def get_probabilistic_labels(self,
                                 collated_labels: np.ndarray) -> np.ndarray:
        """Return the label model's ``predict_proba`` output for the matrix."""
        sparse_labels = sparse.csr_matrix(collated_labels)
        return self.label_model.predict_proba(sparse_labels)

    def convert_to_tags(
        self,
        train_probs: np.ndarray,
        word_list: List[str],
        id_to_labels: Dict[int, Tuple[int, int]],
    ) -> List[AnnotatedDataType]:
        """Turn per-word probabilities back into one tagged entry per id.

        Args:
            train_probs: one probability row per word.
            word_list: flat word list aligned with ``train_probs``.
            id_to_labels: entry id -> ``(start, end)`` row range.

        Returns:
            A list of ``{'id', 'input', 'output'}`` dicts, where ``'output'``
            holds the tag of the most probable column for each word.
        """
        output = []
        for entry_id, (label_start, label_end) in id_to_labels.items():
            words = word_list[label_start:label_end]
            label_ids = train_probs[label_start:label_end].argmax(axis=1)
            output.append({
                'id': entry_id,
                'input': words,
                'output': [self.get_tag(i) for i in label_ids],
            })
        return output

    def collate(
            self,
            annotations: List[AnnotatedDataType],
            should_verify: bool = False,
            descriptions: Optional[List[str]] = None,
            train_data: Optional[AnnotatedDataType] = None
    ) -> AnnotatedDataType:
        '''
        args:
            ``annotations``: List[AnnotatedDataType]
                given a series of annotations, collate them into a single
                series of annotations per instance
            ``should_verify``: bool
                when true, run ``Collator.verify_annotations`` first
            ``descriptions``: Optional[List[str]]
                names of the annotation sources, used for logging only
            ``train_data``: Optional[AnnotatedDataType]
                labelled data forwarded to the label model as ``Y_dev``
        '''
        if should_verify:
            # make sure the annotations are in the
            # proper format
            Collator.verify_annotations(annotations)

        train_data_np = None
        if train_data:
            # if train data is specified, it will be used by Snorkel to
            # estimate the class balance
            train_data_np, _, _ = self.collate_np([train_data])
            # A single source yields one column; flatten it to a 1-D int
            # vector.
            train_data_np = train_data_np.astype(int).reshape(-1)
        collate_np, word_lists, id_to_labels = self.collate_np(annotations)
        self.train_label_model(collated_labels=collate_np,
                               descriptions=descriptions,
                               train_data_np=train_data_np)
        y_train_probs = self.get_probabilistic_labels(
            collated_labels=collate_np)
        tags = self.convert_to_tags(y_train_probs,
                                    word_list=word_lists,
                                    id_to_labels=id_to_labels)
        return tags
コード例 #9
0
ファイル: oc_train_tune.py プロジェクト: AshRamty/heart_mri
def train_model(args):
    """Train a label model, then tune a CNN-LSTM end model on its output.

    Steps, in order:
      1. Load the label matrices ``L`` and labels ``Y`` via ``load_labels``.
      2. Fit a ``LabelModel`` and print its dev metrics next to those of a
         ``MajorityLabelVoter``.
      3. Build data loaders from the label model's training probabilities.
      4. Search end-model hyperparameters with ``RandomSearchTuner`` and
         score the returned model on the dev loader.

    Args:
        args: options object; this function reads ``batch_size``,
            ``num_workers``, ``seed`` and ``checkpoint_dir`` from it and
            forwards it to ``load_labels`` and ``load_dataset``.
    """
    n_hidden = 128
    n_classes = 2
    frame_dim = 1000  # using get_frm_output_size()

    L, Y = load_labels(args)

    # ---- Label model ----
    # Labeling-function analysis on the dev split.
    print(lf_summary(L["dev"], Y=Y["dev"]))

    label_model = LabelModel(k=n_classes, seed=123)
    label_model.train_model(
        L["train"], Y["dev"], n_epochs=500, log_train_every=50)

    print('Trained Label Model Metrics:')
    label_model.score(
        (L["dev"], Y["dev"]),
        metric=['accuracy', 'precision', 'recall', 'f1'])

    # Baseline for comparison: majority vote over the labeling functions.
    majority_voter = MajorityLabelVoter(seed=123)
    print('Majority Label Voter Metrics:')
    majority_voter.score(
        (L["dev"], Y["dev"]),
        metric=['accuracy', 'precision', 'recall', 'f1'])

    train_marginals = label_model.predict_proba(L["train"])

    # ---- End model ----
    # Datasets and loaders: probabilistic labels for train, given labels for
    # dev and test.
    train, dev, test = load_dataset(
        args, train_marginals, Y["dev"], Y["test"])
    data_loader = get_data_loader(
        train, dev, test, args.batch_size, args.num_workers)

    frame_encoder = FrameEncoderOC

    has_cuda = torch.cuda.is_available()
    device = 'cuda' if has_cuda else 'cpu'

    lstm_module = LSTMModule(
        frame_dim,
        n_hidden,
        bidirectional=False,
        verbose=False,
        lstm_reduction="attention",
        encoder_class=frame_encoder,
    )

    train_kwargs = dict(
        seed=args.seed,
        progress_bar=True,
        log_train_every=1,
    )

    init_kwargs = dict(
        input_module=lstm_module,
        optimizer="adam",
        verbose=False,
        input_batchnorm=True,
        use_cuda=has_cuda,
        checkpoint_dir=args.checkpoint_dir,
        seed=args.seed,
        device=device,
    )

    # Fixed values are given as lists; ranged values are sampled on a log
    # scale.
    search_space = {
        'n_epochs': [10],
        'batchnorm': [True],
        'dropout': [0.1, 0.25, 0.4],
        'lr': {'range': [1e-3, 1e-2], 'scale': 'log'},
        'l2': {'range': [1e-5, 1e-4], 'scale': 'log'},
    }

    # Logger and searcher.
    tuner = RandomSearchTuner(
        EndModel,
        log_dir="./run_logs",
        run_name='cnn_lstm_oc',
        log_writer_class=TensorBoardWriter,
        validation_metric='accuracy',
        seed=1701,
    )

    disc_model = tuner.search(
        search_space,
        valid_data=data_loader["dev"],
        train_args=[data_loader["train"]],
        init_args=[[n_hidden, n_classes]],
        init_kwargs=init_kwargs,
        train_kwargs=train_kwargs,
        max_search=5,
        clean_up=False,
    )

    # Evaluate the model returned by the search.
    disc_model.score(
        data_loader["dev"],
        verbose=True,
        metric=['accuracy', 'precision', 'recall', 'f1'])
コード例 #10
0
# Ground-truth labels for both splits, passed through `remap_labels`.
train_ground = remap_labels(loader.train_ground)
val_ground = remap_labels(loader.val_ground)

# Rebuild each label matrix from its raw CSC components with remapped stored
# values, then transpose the result.
train_csc_parts = (remap_labels(L_train_sparse.data), L_train_sparse.indices,
                   L_train_sparse.indptr)
L_train_sparse = sparse.csc_matrix(train_csc_parts).T
val_csc_parts = (remap_labels(L_val_sparse.data), L_val_sparse.indices,
                 L_val_sparse.indptr)
L_val_sparse = sparse.csc_matrix(val_csc_parts).T

print('\n\n####### Running METAL Label Model ########')
label_model = LabelModel()
label_model.train_model(
    L_train_sparse, n_epochs=200, print_every=50, seed=123, verbose=False)
train_marginals = label_model.predict_proba(L_train_sparse)
label_model.score((L_train_sparse, train_ground), metric=metrics)

# ---- METAL with exact class balance ----
print(
    '\n\n####### Running METAL Label Model with exact class balance ########')
# Fraction of examples carrying label 1 and label 2, for each split.
train_class_balance = np.array(
    [np.sum(train_ground == label) / loader.train_num for label in (1, 2)])
val_class_balance = np.array(
    [np.sum(val_ground == label) / loader.val_num for label in (1, 2)])
print('Train set class balance:', train_class_balance)
print('Val set class balance:', val_class_balance)
コード例 #11
0
ファイル: oc_train.py プロジェクト: AshRamty/heart_mri
def train_model(args):
    """Train a label model, then train and evaluate a CNN-LSTM end model.

    Steps, in order:
      1. Load the label matrices ``L`` and labels ``Y`` via ``load_labels``.
      2. Fit a ``LabelModel`` and print its dev metrics next to those of a
         ``MajorityLabelVoter``.
      3. Build data loaders from the label model's training probabilities.
      4. Build an ``EndModel``, pickle its constructor kwargs into the
         checkpoint directory, train it, and score it on dev and test.

    Args:
        args: options object; this function reads ``batch_size``,
            ``num_workers``, ``lstm_reduction``, ``requires_grad``, ``seed``,
            ``checkpoint_dir``, ``weight_decay``, ``lr`` and ``n_epochs``
            from it and forwards it to ``load_labels`` and ``load_dataset``.
    """
    hidden_size = 128
    num_classes = 2
    encode_dim = 1000  # using get_frm_output_size()

    L, Y = load_labels(args)

    # Label Model
    # labelling functions analysis
    print(lf_summary(L["dev"], Y=Y["dev"]))

    # training label model
    label_model = LabelModel(k=num_classes, seed=123)
    label_model.train_model(L["train"],
                            Y["dev"],
                            n_epochs=2000,
                            log_train_every=100)

    # evaluating label model
    print('Trained Label Model Metrics:')
    label_model.score((L["dev"], Y["dev"]),
                      metric=['accuracy', 'precision', 'recall', 'f1'])

    # comparison with majority vote of LFs
    mv = MajorityLabelVoter(seed=123)
    print('Majority Label Voter Metrics:')
    mv.score((L["dev"], Y["dev"]),
             metric=['accuracy', 'precision', 'recall', 'f1'])

    Ytrain_p = label_model.predict_proba(L["train"])

    # End Model
    # Create datasets and dataloaders: probabilistic labels for train, given
    # labels for dev and test.
    train, dev, test = load_dataset(args, Ytrain_p, Y["dev"], Y["test"])
    data_loader = get_data_loader(train, dev, test, args.batch_size,
                                  args.num_workers)

    # Define input encoder
    cnn_encoder = FrameEncoderOCDense

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # Define LSTM module
    lstm_module = LSTMModule(
        encode_dim,
        hidden_size,
        bidirectional=False,
        verbose=False,
        lstm_reduction=args.lstm_reduction,
        encoder_class=cnn_encoder,
        encoder_kwargs={"requires_grad": args.requires_grad})

    # NOTE(review): `cuda` is not assigned anywhere in this function, so it
    # must be a module-level name -- confirm it exists and agrees with
    # `device` above.
    init_kwargs = {
        "layer_out_dims": [hidden_size, num_classes],
        "input_module": lstm_module,
        "optimizer": "adam",
        "verbose": False,
        "input_batchnorm": False,
        "use_cuda": cuda,
        'seed': args.seed,
        'device': device
    }

    end_model = EndModel(**init_kwargs)

    # `makedirs(exist_ok=True)` replaces the exists()/mkdir() pair: it has no
    # check-then-create race and also creates missing parent directories.
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    # Persist the constructor arguments next to the checkpoints.
    init_kwargs_path = os.path.join(args.checkpoint_dir, 'init_kwargs.pickle')
    with open(init_kwargs_path, "wb") as f:
        pickle.dump(init_kwargs, f, protocol=pickle.HIGHEST_PROTOCOL)

    dropout = 0.4
    # Train end model
    end_model.train_model(
        train_data=data_loader["train"],
        valid_data=data_loader["dev"],
        l2=args.weight_decay,
        lr=args.lr,
        n_epochs=args.n_epochs,
        log_train_every=1,
        verbose=True,
        progress_bar=True,
        loss_weights=[0.55, 0.45],
        input_dropout=0.1,
        middle_dropout=dropout,
        checkpoint_dir=args.checkpoint_dir,
    )

    # evaluate end model
    print("Dev Set Performance")
    end_model.score(
        data_loader["dev"],
        verbose=True,
        metric=['accuracy', 'precision', 'recall', 'f1', 'roc-auc', 'ndcg'])
    print("Test Set Performance")
    end_model.score(
        data_loader["test"],
        verbose=True,
        metric=['accuracy', 'precision', 'recall', 'f1', 'roc-auc', 'ndcg'])