Example #1
def main(args):
    # Initialize Emmental
    config = parse_args_to_config(args)
    emmental.init(log_dir=config["meta_config"]["log_path"], config=config)

    # Log configuration into files
    cmd_msg = " ".join(sys.argv)
    logger.info(f"COMMAND: {cmd_msg}")
    write_to_file(f"{emmental.Meta.log_path}/cmd.txt", cmd_msg)

    logger.info(f"Config: {emmental.Meta.config}")
    write_to_file(f"{emmental.Meta.log_path}/config.txt", emmental.Meta.config)

    # Create dataloaders
    dataloaders = get_dataloaders(args)

    # Assign transforms to dataloaders
    aug_dataloaders = []
    if args.augment_policy:
        for idx in range(len(dataloaders)):
            if dataloaders[idx].split in args.train_split:
                dataloaders[idx].dataset.transform_cls = Augmentation(
                    args=args)

    config["learner_config"]["task_scheduler_config"][
        "task_scheduler"] = AugScheduler(augment_k=args.augment_k,
                                         enlarge=args.augment_enlarge)
    emmental.Meta.config["learner_config"]["task_scheduler_config"][
        "task_scheduler"] = config["learner_config"]["task_scheduler_config"][
            "task_scheduler"]

    # Create tasks
    model = EmmentalModel(name=f"{args.task}_task")
    model.add_task(create_task(args))

    # Set cudnn benchmark
    cudnn.benchmark = True

    # Load the best model from the pretrained checkpoint, if one is given
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders + aug_dataloaders)

    # Remove all extra augmentation policies
    for idx in range(len(dataloaders)):
        dataloaders[idx].dataset.transform_cls = None

    scores = model.score(dataloaders)

    # Save metrics and models
    logger.info(f"Metrics: {scores}")
    scores["log_path"] = emmental.Meta.log_path
    write_to_json_file(f"{emmental.Meta.log_path}/metrics.txt", scores)
    model.save(f"{emmental.Meta.log_path}/last_model.pth")
Example #2
def main(args):
    # Initialize Emmental
    config = parse_args_to_config(args)
    emmental.init(log_dir=config["meta_config"]["log_path"], config=config)

    # Log configuration into files
    cmd_msg = " ".join(sys.argv)
    logger.info(f"COMMAND: {cmd_msg}")
    write_to_file(f"{emmental.Meta.log_path}/cmd.txt", cmd_msg)

    logger.info(f"Config: {emmental.Meta.config}")
    write_to_file(f"{emmental.Meta.log_path}/config.txt", emmental.Meta.config)

    # Create dataloaders
    dataloaders = get_dataloaders(args)

    config["learner_config"]["task_scheduler_config"][
        "task_scheduler"] = AugScheduler(augment_k=args.augment_k,
                                         enlarge=args.augment_enlarge)
    emmental.Meta.config["learner_config"]["task_scheduler_config"][
        "task_scheduler"] = config["learner_config"]["task_scheduler_config"][
            "task_scheduler"]

    # Specify parameter groups for the BERT optimizer
    # (no weight decay on biases or LayerNorm weights)
    def grouped_parameters(model):
        no_decay = ["bias", "LayerNorm.weight"]
        return [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                emmental.Meta.config["learner_config"]["optimizer_config"]
                ["l2"],
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]

    emmental.Meta.config["learner_config"]["optimizer_config"][
        "parameters"] = grouped_parameters

    # Create tasks
    model = EmmentalModel(name=f"{args.task}_task")
    model.add_task(create_task(args))

    # Load the best model from the pretrained checkpoint, if one is given
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders)

    # Remove all extra augmentation policies
    for idx in range(len(dataloaders)):
        dataloaders[idx].dataset.transform_cls = None
        dataloaders[idx].dataset.k = 1

    scores = model.score(dataloaders)

    # Save metrics and models
    logger.info(f"Metrics: {scores}")
    scores["log_path"] = emmental.Meta.log_path
    write_to_json_file(f"{emmental.Meta.log_path}/metrics.txt", scores)
    model.save(f"{emmental.Meta.log_path}/last_model.pth")
Example #3
def test_e2e():
    """Run an end-to-end test on documents of the hardware domain."""
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 12

    fonduer.init_logging(
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = fonduer.Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)
    assert session.query(Document).count() == max_docs

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    assert len(docs[0].sentences) == 799
    assert len(docs[1].sentences) == 663
    assert len(docs[2].sentences) == 784
    assert len(docs[3].sentences) == 661
    assert len(docs[4].sentences) == 513
    assert len(docs[5].sentences) == 700
    assert len(docs[6].sentences) == 528
    assert len(docs[7].sentences) == 161
    assert len(docs[8].sentences) == 228
    assert len(docs[9].sentences) == 511
    assert len(docs[10].sentences) == 331
    assert len(docs[11].sentences) == 528

    # Check table numbers
    assert len(docs[0].tables) == 9
    assert len(docs[1].tables) == 9
    assert len(docs[2].tables) == 14
    assert len(docs[3].tables) == 11
    assert len(docs[4].tables) == 11
    assert len(docs[5].tables) == 10
    assert len(docs[6].tables) == 10
    assert len(docs[7].tables) == 2
    assert len(docs[8].tables) == 7
    assert len(docs[9].tables) == 10
    assert len(docs[10].tables) == 6
    assert len(docs[11].tables) == 9

    # Check figure numbers
    assert len(docs[0].figures) == 32
    assert len(docs[1].figures) == 11
    assert len(docs[2].figures) == 38
    assert len(docs[3].figures) == 31
    assert len(docs[4].figures) == 7
    assert len(docs[5].figures) == 38
    assert len(docs[6].figures) == 10
    assert len(docs[7].figures) == 31
    assert len(docs[8].figures) == 4
    assert len(docs[9].figures) == 27
    assert len(docs[10].figures) == 5
    assert len(docs[11].figures) == 27

    # Check caption numbers
    assert len(docs[0].captions) == 0
    assert len(docs[1].captions) == 0
    assert len(docs[2].captions) == 0
    assert len(docs[3].captions) == 0
    assert len(docs[4].captions) == 0
    assert len(docs[5].captions) == 0
    assert len(docs[6].captions) == 0
    assert len(docs[7].captions) == 0
    assert len(docs[8].captions) == 0
    assert len(docs[9].captions) == 0
    assert len(docs[10].captions) == 0
    assert len(docs[11].captions) == 0

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (doc_name, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 138
    assert session.query(Volt).count() == 140
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299
    assert (len(
        mention_extractor.get_mentions(docs=[
            session.query(Document).filter(Document.name == "112823").first()
        ])[0]) == 70)

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt],
        throttlers=[temp_throttler, volt_throttler])
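    # One throttler per candidate class, in the same order: each throttler
    # prunes candidate tuples before they are materialized in the database.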

    for i, docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3493
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 416
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3493
    assert (len(
        candidate_extractor.get_candidates(docs=[
            session.query(Document).filter(Document.name == "112823").first()
        ])[0]) == 1432)

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    assert session.query(FeatureKey).count() == 1260

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NNP NN IN]"])
    assert session.query(FeatureKey).count() == 1259

    # Should only remove part_volt as a relation and leave part_temp
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name ==
        "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes == ["part_temp"]
    assert session.query(FeatureKey).count() == 1259

    # Inserting the removed key
    featurizer.upsert_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                           candidate_classes=[PartTemp, PartVolt])
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    assert session.query(FeatureKey).count() == 1259
    # Removing the key again
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1258
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6478
    assert session.query(FeatureKey).count() == 4538
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3493, 4538)
    assert F_train[1].shape == (2985, 4538)
    assert len(featurizer.get_keys()) == 4538

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6692
    assert session.query(FeatureKey).count() == 4538
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (61, 4538)
    assert F_dev[1].shape == (153, 4538)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8252
    assert session.query(FeatureKey).count() == 4538
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (416, 4538)
    assert F_test[1].shape == (1144, 4538)

    gold_file = "tests/data/hardware_tutorial_gold.csv"

    labeler = Labeler(session, [PartTemp, PartVolt])

    labeler.apply(
        docs=last_docs,
        lfs=[[gold], [gold]],
        table=GoldLabel,
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(GoldLabel).count() == 8252

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]

    with pytest.raises(ValueError):
        labeler.apply(split=0,
                      lfs=stg_temp_lfs,
                      train=True,
                      parallelism=PARALLEL)

    labeler.apply(
        docs=train_docs,
        lfs=[stg_temp_lfs, ce_v_max_lfs],
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    # Test Dropping LabelerKey
    labeler.drop_keys(["LF_storage_row"])
    assert len(labeler.get_keys()) == 8

    # Test Upserting LabelerKey
    labeler.upsert_keys(["LF_storage_row"])
    assert "LF_storage_row" in [label.name for label in labeler.get_keys()]

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3493, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3493, 1)

    label_model = LabelModel()
    label_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    # Collect word counter
    word_counter = collect_word_counter(train_cands)

    emmental.init(fonduer.Meta.log_path)

    # Training config
    config = {
        "meta_config": {
            "verbose": False
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.0
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    f"{ATTRIBUTE}/{ATTRIBUTE}/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }
    emmental.Meta.update_config(config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)

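    # Keep only examples the label model did not abstain on: a uniform
    # marginal has max == min, so rows with zero spread carry no signal.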
    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LogisticRegression")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(ATTRIBUTE, test_cands[0], F_test[0],
                               emb_layer.word2id, 2),
        split="test",
        batch_size=100,
        shuffle=False,
    )

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.6)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert 0.3 < f1 < 0.7
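
    # The precision/recall/F1 block above is repeated verbatim after each of
    # the three training runs in this test; a small helper would factor it
    # out, e.g. (a sketch, not part of the original test):
    #
    #     def prf1(TP, FP, FN):
    #         tp, fp, fn = len(TP), len(FP), len(FN)
    #         prec = tp / (tp + fp) if tp + fp > 0 else float("nan")
    #         rec = tp / (tp + fn) if tp + fn > 0 else float("nan")
    #         f1 = (2 * prec * rec / (prec + rec)
    #               if prec + rec > 0 else float("nan"))
    #         return prec, rec, f1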

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0,
                   lfs=[stg_temp_lfs_2, ce_v_max_lfs],
                   parallelism=PARALLEL)
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 16)

    label_model = LabelModel()
    label_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    valid_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            np.argmax(train_marginals, axis=1),
            train_idxs,
        ),
        split="valid",
        batch_size=100,
        shuffle=False,
    )

    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LogisticRegression")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader, valid_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7

    # Testing LSTM
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LSTM")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7
Example #4
        )

    tasks = {
        task_name: create_task(
            task_name, args, datasets[task_name]["nclasses"], emb_layer
        )
        for task_name in args.task
    }

    model = EmmentalModel(name="TC_task")

    if Meta.config["model_config"]["model_path"]:
        model.load(Meta.config["model_config"]["model_path"])
    else:
        for task_name, task in tasks.items():
            model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, dataloaders)

    scores = model.score(dataloaders)
    logger.info(f"Metrics: {scores}")
    write_to_json_file(f"{Meta.log_path}/metrics.txt", scores)

    if args.checkpointing:
        logger.info(
            f"Best metrics: "
            f"{emmental_learner.logging_manager.checkpointer.best_metric_dict}"
        )
        write_to_file(
            f"{Meta.log_path}/best_metrics.txt",
Example #5
                    split=split,
                    batch_size=args.batch_size,
                    shuffle=(split == "train"),
                )
            )
            logger.info(f"Built dataloader for {task_name} {split} set.")

    tasks = get_gule_task(args.task, args.bert_model)

    mtl_model = EmmentalModel(name="GLUE_multi_task")

    if Meta.config["model_config"]["model_path"]:
        mtl_model.load(Meta.config["model_config"]["model_path"])
    else:
        for task_name, task in tasks.items():
            mtl_model.add_task(task)

    emmental_learner = EmmentalLearner()

    emmental_learner.learn(mtl_model, dataloaders)

    scores = mtl_model.score(dataloaders)
    logger.info(f"Metrics: {scores}")
    write_to_file("metrics.txt", scores)
    logger.info(
        f"Best metrics: "
        f"{emmental_learner.logging_manager.checkpointer.best_metric_dict}"
    )
    write_to_file(
        "best_metrics.txt",
        emmental_learner.logging_manager.checkpointer.best_metric_dict,
Example #6
def test_model(caplog):
    """Unit test of model."""
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_model"

    Meta.reset()
    emmental.init(dirpath)

    def ce_loss(module_name, immediate_output_dict, Y, active):
        return F.cross_entropy(immediate_output_dict[module_name][0][active],
                               (Y.view(-1))[active])

    def output(module_name, immediate_output_dict):
        return F.softmax(immediate_output_dict[module_name][0], dim=1)

    task1 = EmmentalTask(
        name="task_1",
        module_pool=nn.ModuleDict({
            "m1": nn.Linear(10, 10, bias=False),
            "m2": nn.Linear(10, 2, bias=False)
        }),
        task_flow=[
            {
                "name": "m1",
                "module": "m1",
                "inputs": [("_input_", "data")]
            },
            {
                "name": "m2",
                "module": "m2",
                "inputs": [("m1", 0)]
            },
        ],
        loss_func=partial(ce_loss, "m2"),
        output_func=partial(output, "m2"),
        scorer=Scorer(metrics=["accuracy"]),
    )
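
    # task_flow defines a small dataflow graph: "m1" consumes the raw input
    # tensor ("_input_", "data"), "m2" consumes output 0 of "m1", and the
    # loss/output functions read "m2" from the intermediate output dict.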

    new_task1 = EmmentalTask(
        name="task_1",
        module_pool=nn.ModuleDict({
            "m1": nn.Linear(10, 5, bias=False),
            "m2": nn.Linear(5, 2, bias=False)
        }),
        task_flow=[
            {
                "name": "m1",
                "module": "m1",
                "inputs": [("_input_", "data")]
            },
            {
                "name": "m2",
                "module": "m2",
                "inputs": [("m1", 0)]
            },
        ],
        loss_func=partial(ce_loss, "m2"),
        output_func=partial(output, "m2"),
        scorer=Scorer(metrics=["accuracy"]),
    )

    task2 = EmmentalTask(
        name="task_2",
        module_pool=nn.ModuleDict({
            "m1": nn.Linear(10, 5, bias=False),
            "m2": nn.Linear(5, 2, bias=False)
        }),
        task_flow=[
            {
                "name": "m1",
                "module": "m1",
                "inputs": [("_input_", "data")]
            },
            {
                "name": "m2",
                "module": "m2",
                "inputs": [("m1", 0)]
            },
        ],
        loss_func=partial(ce_loss, "m2"),
        output_func=partial(output, "m2"),
        scorer=Scorer(metrics=["accuracy"]),
    )

    config = {"model_config": {"dataparallel": False}}
    emmental.Meta.update_config(config)

    model = EmmentalModel(name="test", tasks=task1)

    assert repr(model) == "EmmentalModel(name=test)"
    assert model.name == "test"
    assert model.task_names == set(["task_1"])
    assert model.module_pool["m1"].weight.data.size() == (10, 10)
    assert model.module_pool["m2"].weight.data.size() == (2, 10)

    model.update_task(new_task1)

    assert model.module_pool["m1"].weight.data.size() == (5, 10)
    assert model.module_pool["m2"].weight.data.size() == (2, 5)

    model.update_task(task2)

    assert model.task_names == set(["task_1"])

    model.add_task(task2)

    assert model.task_names == set(["task_1", "task_2"])

    model.remove_task("task_1")
    assert model.task_names == set(["task_2"])

    model.remove_task("task_1")
    assert model.task_names == set(["task_2"])

    model.save(f"{dirpath}/saved_model.pth")

    model.load(f"{dirpath}/saved_model.pth")

    # Test add_tasks
    model = EmmentalModel(name="test")

    model.add_tasks([task1, task2])
    assert model.task_names == set(["task_1", "task_2"])

    shutil.rmtree(dirpath)
Example #7
    def train_model(cands, F, align_type, model_type="LogisticRegression"):
        # Extract candidates and features based on the align type (row/column)
        align_val = 0 if align_type == "row" else 1
        train_cands = cands[align_val][0]
        F_train = F[align_val][0]
        train_marginals = np.array([[0, 1] if gold[align_val](x) else [1, 0]
                                    for x in train_cands[0]])

        # 1.) Setup training config
        config = {
            "meta_config": {
                "verbose": True
            },
            "model_config": {
                "model_path": None,
                "device": 0,
                "dataparallel": False
            },
            "learner_config": {
                "n_epochs": 50,
                "optimizer_config": {
                    "lr": 0.001,
                    "l2": 0.0
                },
                "task_scheduler": "round_robin",
            },
            "logging_config": {
                "evaluation_freq": 1,
                "counter_unit": "epoch",
                "checkpointing": False,
                "checkpointer_config": {
                    "checkpoint_metric": {
                        f"{ATTRIBUTE}/{ATTRIBUTE}/train/loss": "min"
                    },
                    "checkpoint_freq": 1,
                    "checkpoint_runway": 2,
                    "clear_intermediate_checkpoints": True,
                    "clear_all_checkpoints": True,
                },
            },
        }

        emmental.init(Meta.log_path)
        emmental.Meta.update_config(config=config)

        # 2.) Collect word counter from training data
        word_counter = collect_word_counter(train_cands)

        # 3.) Generate the word embedding module (needed for the LSTM model;
        # for Logistic Regression we still build it, since FonduerDataset
        # requires the word2id dict)
        # Generate special tokens
        arity = 2
        specials = []
        for i in range(arity):
            specials += [f"~~[[{i}", f"{i}]]~~"]

        emb_layer = EmbeddingModule(word_counter=word_counter,
                                    word_dim=300,
                                    specials=specials)

        # 4.) Generate dataloader for training set
        # No noise in Gold labels
        train_dataloader = EmmentalDataLoader(
            task_to_label_dict={ATTRIBUTE: "labels"},
            dataset=FonduerDataset(
                ATTRIBUTE,
                train_cands[0],
                F_train[0],
                emb_layer.word2id,
                train_marginals,
            ),
            split="train",
            batch_size=100,
            shuffle=True,
        )

        # 5.) Training
        tasks = create_task(
            ATTRIBUTE,
            2,
            F_train[0].shape[1],
            2,
            emb_layer,
            model=model_type  # "LSTM" 
        )

        model = EmmentalModel(name=f"{ATTRIBUTE}_task")

        for task in tasks:
            model.add_task(task)

        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, [train_dataloader])

        return (model, emb_layer)
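
    # Hypothetical usage sketch: `cands` and `F` are indexed by align type
    # first (0 = row, 1 = column), matching the indexing at the top of
    # train_model:
    #
    #     row_model, row_emb = train_model(cands, F, "row")
    #     col_model, col_emb = train_model(cands, F, "column", model_type="LSTM")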
Example #8
def main(
    conn_string,
    gain=False,
    current=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    parallel=8,
    log_dir="logs",
    verbose=False,
):
    # Setup initial configuration
    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    rel_list = []
    if gain:
        rel_list.append("gain")

    if current:
        rel_list.append("current")

    logger.info(f"=" * 30)
    logger.info(f"Running with parallel: {parallel}, max_docs: {max_docs}")

    session = Meta.init(conn_string).Session()

    # Parsing
    start = timer()
    logger.info(f"Starting parsing...")
    docs, train_docs, dev_docs, test_docs = parse_dataset(session,
                                                          dirname,
                                                          first_time=parse,
                                                          parallel=parallel,
                                                          max_docs=max_docs)
    logger.debug(f"Done")
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of Documents: {len(docs)}")
    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    if gain:
        Gain = mention_subclass("Gain")
        gain_matcher = get_gain_matcher()
        gain_ngrams = MentionNgrams(n_max=2)
        mentions.append(Gain)
        ngrams.append(gain_ngrams)
        matchers.append(gain_matcher)

    if current:
        Current = mention_subclass("SupplyCurrent")
        current_matcher = get_supply_current_matcher()
        current_ngrams = MentionNgramsCurrent(n_max=3)
        mentions.append(Current)
        ngrams.append(current_ngrams)
        matchers.append(current_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")

    if gain:
        logger.info(f"Total Gain: {session.query(Gain).count()}")

    if current:
        logger.info(f"Total Current: {session.query(Current).count()}")

    cand_classes = []
    if gain:
        GainCand = candidate_subclass("GainCand", [Gain])
        cand_classes.append(GainCand)
    if current:
        CurrentCand = candidate_subclass("CurrentCand", [Current])
        cand_classes.append(CurrentCand)

    candidate_extractor = CandidateExtractor(session, cand_classes)

    if first_time:
        for i, docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(docs, split=i, parallelism=parallel)

    # These must be sorted for deterministic behavior.
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    logger.info(
        f"Total train candidate: {len(train_cands[0]) + len(train_cands[1])}")
    logger.info(
        f"Total dev candidate: {len(dev_cands[0]) + len(dev_cands[1])}")
    logger.info(
        f"Total test candidate: {len(test_cands[0]) + len(test_cands[1])}")

    logger.info("Done w/ candidate extraction.")
    end = timer()
    logger.warning(f"CE Time (min): {((end - start) / 60.0):.1f}")

    # First, check total recall
    #  result = entity_level_scores(
    #      candidates_to_entities(dev_cands[0], is_gain=True),
    #      corpus=dev_docs,
    #      is_gain=True,
    #  )
    #  logger.info(f"Gain Total Dev Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #  result = entity_level_scores(
    #      candidates_to_entities(test_cands[0], is_gain=True),
    #      corpus=test_docs,
    #      is_gain=True,
    #  )
    #  logger.info(f"Gain Total Test Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #
    #  result = entity_level_scores(
    #      candidates_to_entities(dev_cands[1], is_gain=False),
    #      corpus=dev_docs,
    #      is_gain=False,
    #  )
    #  logger.info(f"Current Total Dev Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #  result = entity_level_scores(
    #      candidates_to_entities(test_cands[1], is_gain=False),
    #      corpus=test_docs,
    #      is_gain=False,
    #  )
    #  logger.info(f"Current Test Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")

    start = timer()

    # Using parallelism = 1 for deterministic behavior.
    featurizer = Featurizer(session, cand_classes, parallelism=1)

    if first_time:
        logger.info("Starting featurizer...")
        # Set feature space based on dev set, which we use for training rather
        # than the large train set.
        featurizer.apply(split=1, train=True)
        featurizer.apply(split=0)
        featurizer.apply(split=2)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    # Serialize feature matrices on first run
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        F_train_dict = {}
        F_dev_dict = {}
        F_test_dict = {}
        for idx, relation in enumerate(rel_list):
            F_train_dict[relation] = F_train[idx]
            F_dev_dict[relation] = F_dev[idx]
            F_test_dict[relation] = F_test[idx]

        pickle.dump(F_train_dict,
                    open(os.path.join(dirname, "F_train_dict.pkl"), "wb"))
        pickle.dump(F_dev_dict,
                    open(os.path.join(dirname, "F_dev_dict.pkl"), "wb"))
        pickle.dump(F_test_dict,
                    open(os.path.join(dirname, "F_test_dict.pkl"), "wb"))
    else:
        F_train_dict = pickle.load(
            open(os.path.join(dirname, "F_train_dict.pkl"), "rb"))
        F_dev_dict = pickle.load(
            open(os.path.join(dirname, "F_dev_dict.pkl"), "rb"))
        F_test_dict = pickle.load(
            open(os.path.join(dirname, "F_test_dict.pkl"), "rb"))

        F_train = []
        F_dev = []
        F_test = []
        for relation in rel_list:
            F_train.append(F_train_dict[relation])
            F_dev.append(F_dev_dict[relation])
            F_test.append(F_test_dict[relation])

    logger.info("Done.")

    start = timer()
    logger.info("Labeling training data...")
    #  labeler = Labeler(session, cand_classes)
    #  lfs = []
    #  if gain:
    #      lfs.append(gain_lfs)
    #
    #  if current:
    #      lfs.append(current_lfs)
    #
    #  if first_time:
    #      logger.info("Applying LFs...")
    #      labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
    #  elif re_label:
    #      logger.info("Re-applying LFs...")
    #      labeler.update(split=0, lfs=lfs, parallelism=parallel)
    #
    #  logger.info("Done...")

    #  logger.info("Getting label matrices...")
    #  L_train = labeler.get_label_matrices(train_cands)
    #  logger.info("Done...")

    if first_time:
        marginals_dict = {}
        for idx, relation in enumerate(rel_list):
            # Manually create marginals from human annotations
            marginal = []
            dev_gold_entities = get_gold_set(is_gain=(relation == "gain"))
            for c in dev_cands[idx]:
                flag = False
                for entity in cand_to_entity(c, is_gain=(relation == "gain")):
                    if entity in dev_gold_entities:
                        flag = True

                if flag:
                    marginal.append([0.0, 1.0])
                else:
                    marginal.append([1.0, 0.0])

            marginals_dict[relation] = np.array(marginal)

        pickle.dump(marginals_dict,
                    open(os.path.join(dirname, "marginals_dict.pkl"), "wb"))
    else:
        marginals_dict = pickle.load(
            open(os.path.join(dirname, "marginals_dict.pkl"), "rb"))

    marginals = []
    for relation in rel_list:
        marginals.append(marginals_dict[relation])

    end = timer()
    logger.warning(
        f"Weak Supervision Time (min): {((end - start) / 60.0):.1f}")

    start = timer()

    word_counter = collect_word_counter(train_cands)

    # Training config
    config = {
        "meta_config": {
            "verbose": True,
            "seed": 30
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 500,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.005
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    "model/all/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }

    emmental.init(log_dir=Meta.log_path, config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)
    train_idxs = []
    train_dataloader = []
    for idx, relation in enumerate(rel_list):
        diffs = marginals[idx].max(axis=1) - marginals[idx].min(axis=1)
        train_idxs.append(np.where(diffs > 1e-6)[0])

        # Only use the dev set (human annotations) as training data
        train_dataloader.append(
            EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(
                    relation,
                    dev_cands[idx],
                    F_dev[idx],
                    emb_layer.word2id,
                    marginals[idx],
                    train_idxs[idx],
                ),
                split="train",
                batch_size=256,
                shuffle=True,
            ))

    num_feature_keys = len(featurizer.get_keys())

    model = EmmentalModel(name=f"opamp_tasks")

    # Lists of relation names, arities, and numbers of classes
    tasks = create_task(
        rel_list,
        [2] * len(rel_list),
        num_feature_keys,
        [2] * len(rel_list),
        emb_layer,
        model="LogisticRegression",
    )

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()

    # Given a list of dataloaders, this trains on multiple tasks at once
    emmental_learner.learn(model, train_dataloader)

    # Build a test dataloader for each relation
    for idx, relation in enumerate(rel_list):
        test_dataloader = EmmentalDataLoader(
            task_to_label_dict={relation: "labels"},
            dataset=FonduerDataset(relation, test_cands[idx], F_test[idx],
                                   emb_layer.word2id, 2),
            split="test",
            batch_size=256,
            shuffle=False,
        )

        test_preds = model.predict(test_dataloader, return_preds=True)

        best_result, best_b = scoring(
            test_preds,
            test_cands[idx],
            test_docs,
            is_gain=(relation == "gain"),
            num=100,
        )

        # Dump CSV files for analysis
        if relation == "gain":
            train_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, train_cands[idx],
                                       F_train[idx], emb_layer.word2id, 2),
                split="train",
                batch_size=256,
                shuffle=False,
            )

            train_preds = model.predict(train_dataloader, return_preds=True)
            Y_prob = np.array(train_preds["probs"][relation])[:, TRUE]
            output_csv(train_cands[idx], Y_prob, is_gain=True)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            output_csv(test_cands[idx], Y_prob, is_gain=True, append=True)
            dump_candidates(test_cands[idx],
                            Y_prob,
                            "gain_test_probs.csv",
                            is_gain=True)

            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=256,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            output_csv(dev_cands[idx], Y_prob, is_gain=True, append=True)
            dump_candidates(dev_cands[idx],
                            Y_prob,
                            "gain_dev_probs.csv",
                            is_gain=True)

        if relation == "current":
            train_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, train_cands[idx],
                                       F_train[idx], emb_layer.word2id, 2),
                split="train",
                batch_size=256,
                shuffle=False,
            )

            train_preds = model.predict(train_dataloader, return_preds=True)
            Y_prob = np.array(train_preds["probs"][relation])[:, TRUE]
            output_csv(train_cands[idx], Y_prob, is_gain=False)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            output_csv(test_cands[idx], Y_prob, is_gain=False, append=True)
            dump_candidates(test_cands[idx],
                            Y_prob,
                            "current_test_probs.csv",
                            is_gain=False)

            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=256,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            output_csv(dev_cands[idx], Y_prob, is_gain=False, append=True)
            dump_candidates(dev_cands[idx],
                            Y_prob,
                            "current_dev_probs.csv",
                            is_gain=False)

    end = timer()
    logger.warning(
        f"Classification AND dump data Time (min): {((end - start) / 60.0):.1f}"
    )
Example #9
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]

    emmental.Meta.config["learner_config"]["optimizer_config"][
        "parameters"] = grouped_parameters

    # Create tasks
    model = EmmentalModel(name="TACRED_task")
    model.add_task(create_task(args))

    # Load the best model from the pretrained checkpoint, if one is given
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders)

    # Remove all extra augmentation policies
    for idx in range(len(dataloaders)):
        dataloaders[idx].dataset.transform_cls = None

    scores = model.score(dataloaders)
Example #10
class BootlegAnnotator(object):
    """BootlegAnnotator class: convenient wrapper of preprocessing and model
    eval to allow for annotating single sentences at a time for quick
    experimentation, e.g. in notebooks.

    Args:
        config: model config (default None)
        device: model device, -1 for CPU (default None)
        max_alias_len: maximum alias length (default 6)
        cand_map: alias candidate map (default None)
        threshold: probability threshold (default 0.0)
        cache_dir: cache directory (default None)
        model_name: model name (default None)
        verbose: verbose boolean (default False)
    """
    def __init__(
        self,
        config=None,
        device=None,
        max_alias_len=6,
        cand_map=None,
        threshold=0.0,
        cache_dir=None,
        model_name=None,
        verbose=False,
    ):
        self.max_alias_len = max_alias_len  # maximum alias length, in words
        self.verbose = verbose
        self.threshold = threshold  # minimum probability for a returned mention

        if not cache_dir:
            self.cache_dir = get_default_cache()
            self.model_path = self.cache_dir / "models"
            self.data_path = self.cache_dir / "data"
        else:
            self.cache_dir = Path(cache_dir)
            self.model_path = self.cache_dir / "models"
            self.data_path = self.cache_dir / "data"

        if not model_name:
            model_name = "bootleg_uncased"

        assert model_name in {
            "bootleg_cased",
            "bootleg_cased_mini",
            "bootleg_uncased",
            "bootleg_uncased_mini",
        }, (f"model_name must be one of [bootleg_cased, bootleg_cased_mini, "
            f"bootleg_uncased_mini, bootleg_uncased]. You have {model_name}.")

        if not config:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self.model_path.mkdir(parents=True, exist_ok=True)
            self.data_path.mkdir(parents=True, exist_ok=True)
            create_sources(self.model_path, self.data_path, model_name)
            self.config = create_config(self.model_path, self.data_path,
                                        model_name)
        else:
            if "emmental" in config:
                config = parse_boot_and_emm_args(config)
            self.config = config
            # Ensure some of the critical annotator args are the correct type
            self.config.data_config.max_aliases = int(
                self.config.data_config.max_aliases)
            self.config.run_config.eval_batch_size = int(
                self.config.run_config.eval_batch_size)
            self.config.data_config.max_seq_len = int(
                self.config.data_config.max_seq_len)
            self.config.data_config.train_in_candidates = bool(
                self.config.data_config.train_in_candidates)

        if not device:
            device = 0 if torch.cuda.is_available() else -1

        if self.verbose:
            self.config.run_config.log_level = "DEBUG"
        else:
            self.config.run_config.log_level = "INFO"

        self.torch_device = (torch.device(device)
                             if device != -1 else torch.device("cpu"))
        self.config.model_config.device = device

        log_level = logging.getLevelName(
            self.config["run_config"]["log_level"].upper())
        emmental.init(
            log_dir=self.config["meta_config"]["log_path"],
            config=self.config,
            use_exact_log_path=self.config["meta_config"]
            ["use_exact_log_path"],
            level=log_level,
        )

        logger.debug("Reading entity database")
        self.entity_db = EntitySymbols.load_from_cache(
            os.path.join(
                self.config.data_config.entity_dir,
                self.config.data_config.entity_map_dir,
            ),
            alias_cand_map_file=self.config.data_config.alias_cand_map,
            alias_idx_file=self.config.data_config.alias_idx_map,
        )
        logger.debug("Reading word tokenizers")
        self.tokenizer = BertTokenizer.from_pretrained(
            self.config.data_config.word_embedding.bert_model,
            do_lower_case=True if "uncased"
            in self.config.data_config.word_embedding.bert_model else False,
            cache_dir=self.config.data_config.word_embedding.cache_dir,
        )

        # Create tasks
        tasks = [NED_TASK]
        if self.config.data_config.type_prediction.use_type_pred is True:
            tasks.append(TYPE_PRED_TASK)
        self.task_to_label_dict = {t: NED_TASK_TO_LABEL[t] for t in tasks}

        # Create tasks
        self.model = EmmentalModel(name="Bootleg")
        self.model.add_task(ned_task.create_task(self.config, self.entity_db))
        if TYPE_PRED_TASK in tasks:
            self.model.add_task(
                type_pred_task.create_task(self.config, self.entity_db))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(self.model)

        logger.debug("Loading model")
        # Load the best model from the pretrained checkpoint
        assert (
            self.config["model_config"]["model_path"] is not None
        ), f"Must have a model to load in the model_path for the BootlegAnnotator"
        self.model.load(self.config["model_config"]["model_path"])
        self.model.eval()
        if cand_map is None:
            alias_map = self.entity_db.get_alias2qids()
        else:
            logger.debug(f"Loading candidate map")
            alias_map = ujson.load(open(cand_map))

        self.all_aliases_trie = get_all_aliases(alias_map, verbose)

        logger.debug("Reading in alias table")
        self.alias2cands = AliasEntityTable(
            data_config=self.config.data_config, entity_symbols=self.entity_db)

        # get batch_on_the_fly embeddings
        self.batch_on_the_fly_embs = get_dataloader_embeddings(
            self.config, self.entity_db)

    def extract_mentions(self, text, label_func):
        """Wrapper function for mention extraction.

        Args:
            text: text to extract mentions from
            label_func: function that performs extraction (input is
                        (text, alias trie, max alias length); output is a list
                        of found aliases and found spans)

        Returns: JSON object of sentence to be used in eval
        """
        found_aliases, found_spans = label_func(text, self.all_aliases_trie,
                                                self.max_alias_len)
        return {
            "sentence": text,
            "aliases": found_aliases,
            "spans": found_spans,
            # we don't know the true QID
            "qids": ["Q-1" for i in range(len(found_aliases))],
            "gold": [True for i in range(len(found_aliases))],
        }

    def set_threshold(self, value):
        """Sets threshold.

        Args:
            value: threshold value

        Returns:
        """
        self.threshold = value
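
    # Usage sketch (assumes a pretrained Bootleg model is available in the
    # cache; the sentence below is illustrative):
    #
    #     ann = BootlegAnnotator(model_name="bootleg_uncased", device=-1)
    #     out = ann.label_mentions("Abraham Lincoln was born in Kentucky")
    #     print(out["qids"], out["titles"])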

    def label_mentions(self,
                       text_list,
                       label_func=find_aliases_in_sentence_tag):
        """Extracts mentions and runs disambiguation.

        Args:
            text_list: list of text to disambiguate (or single sentence)
            label_func: mention extraction function (optional)

        Returns: Dict of

            * ``qids``: final predicted QIDs,
            * ``probs``: final predicted probs,
            * ``titles``: final predicted titles,
            * ``cands``: all entity candidates,
            * ``cand_probs``: probabilities of all candidates,
            * ``spans``: final extracted word spans,
            * ``aliases``: final extracted aliases,
        """
        if type(text_list) is str:
            text_list = [text_list]
        else:
            assert (type(text_list) is list and len(text_list) > 0
                    and type(text_list[0]) is str
                    ), "We only accept inputs of strings and lists of strings"

        ebs = int(self.config.run_config.eval_batch_size)
        self.config.data_config.max_aliases = int(
            self.config.data_config.max_aliases)
        total_start_exs = 0
        total_final_exs = 0
        dropped_by_thresh = 0

        final_char_spans = []

        batch_example_aliases = []
        batch_example_aliases_locs_start = []
        batch_example_aliases_locs_end = []
        batch_example_alias_list_pos = []
        batch_example_true_entities = []
        batch_word_indices = []
        batch_spans_arr = []
        batch_aliases_arr = []
        batch_idx_unq = []
        batch_subsplit_idx = []
        for idx_unq, text in tqdm(
                enumerate(text_list),
                desc="Prepping data",
                total=len(text_list),
                disable=not self.verbose,
        ):
            sample = self.extract_mentions(text, label_func)
            total_start_exs += len(sample["aliases"])
            char_spans = self.get_char_spans(sample["spans"], text)

            final_char_spans.append(char_spans)

            (
                idxs_arr,
                aliases_to_predict_per_split,
                spans_arr,
                phrase_tokens_arr,
                pos_idxs,
            ) = sentence_utils.split_sentence(
                max_aliases=self.config.data_config.max_aliases,
                phrase=sample["sentence"],
                spans=sample["spans"],
                aliases=sample["aliases"],
                aliases_seen_by_model=list(range(len(sample["aliases"]))),
                seq_len=self.config.data_config.max_seq_len,
                is_bert=True,
                tokenizer=self.tokenizer,
            )
            aliases_arr = [[sample["aliases"][idx] for idx in idxs]
                           for idxs in idxs_arr]
            old_spans_arr = [[sample["spans"][idx] for idx in idxs]
                             for idxs in idxs_arr]
            qids_arr = [[sample["qids"][idx] for idx in idxs]
                        for idxs in idxs_arr]
            word_indices_arr = [
                self.tokenizer.convert_tokens_to_ids(pt)
                for pt in phrase_tokens_arr
            ]
            # iterate over each sample in the split

            for sub_idx in range(len(idxs_arr)):
                # ====================================================
                # GENERATE MODEL INPUTS
                # ====================================================
                aliases_to_predict_arr = aliases_to_predict_per_split[sub_idx]

                assert (
                    len(aliases_to_predict_arr) > 0
                ), "There are no aliases to predict for an example. This should not happen at this point."
                assert (
                    len(aliases_arr[sub_idx]) <=
                    self.config.data_config.max_aliases
                ), f"{sample} should have no more than {self.config.data_config.max_aliases} aliases."

                example_aliases = np.ones(
                    self.config.data_config.max_aliases) * PAD_ID
                example_aliases_locs_start = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_aliases_locs_end = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_alias_list_pos = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_true_entities = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)

                for mention_idx, alias in enumerate(aliases_arr[sub_idx]):
                    span_start_idx, span_end_idx = spans_arr[sub_idx][
                        mention_idx]
                    # generate indexes into alias table.
                    alias_trie_idx = self.entity_db.get_alias_idx(alias)
                    alias_qids = np.array(self.entity_db.get_qid_cands(alias))
                    if not qids_arr[sub_idx][mention_idx] in alias_qids:
                        # assert not data_args.train_in_candidates
                        if not self.config.data_config.train_in_candidates:
                            # set class label to be "not in candidate set"
                            true_entity_idx = 0
                        else:
                            true_entity_idx = -2
                    else:
                        # Here we are getting the correct class label for training.
                        # Our training is "which of the max_entities entity candidates is the right one
                        # (class labels 1 to max_entities) or is it none of these (class label 0)".
                        # + (not discard_noncandidate_entities) is to ensure label 0 is
                        # reserved for "not in candidate set" class
                        true_entity_idx = np.nonzero(
                            alias_qids == qids_arr[sub_idx][mention_idx]
                        )[0][0] + (
                            not self.config.data_config.train_in_candidates)
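                        # Example (assumed values): with train_in_candidates=False
                        # and alias_qids == ["Q1", "Q5", "Q9"], a gold QID "Q5"
                        # gives true_entity_idx == 1 + 1 == 2, since index 0 is
                        # reserved for the "not in candidate set" class.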
                    example_aliases[mention_idx] = alias_trie_idx
                    example_aliases_locs_start[mention_idx] = span_start_idx
                    # The span_idxs are [start, end). We want [start, end]. So subtract 1 from end idx.
                    example_aliases_locs_end[mention_idx] = span_end_idx - 1
                    example_alias_list_pos[mention_idx] = idxs_arr[sub_idx][
                        mention_idx]
                    # leave as -1 if it's not an alias we want to predict; we get these if we split a sentence
                    # and need to only predict subsets
                    if mention_idx in aliases_to_predict_arr:
                        example_true_entities[mention_idx] = true_entity_idx

                # get word indices
                word_indices = word_indices_arr[sub_idx]

                batch_example_aliases.append(example_aliases)
                batch_example_aliases_locs_start.append(
                    example_aliases_locs_start)
                batch_example_aliases_locs_end.append(example_aliases_locs_end)
                batch_example_alias_list_pos.append(example_alias_list_pos)
                batch_example_true_entities.append(example_true_entities)
                batch_word_indices.append(word_indices)
                batch_aliases_arr.append(aliases_arr[sub_idx])
                # Add the original sample spans because spans_arr is w.r.t. BERT subword tokens
                batch_spans_arr.append(old_spans_arr[sub_idx])
                batch_idx_unq.append(idx_unq)
                batch_subsplit_idx.append(sub_idx)

        batch_example_aliases = torch.tensor(batch_example_aliases).long()
        batch_example_aliases_locs_start = torch.tensor(
            batch_example_aliases_locs_start, device=self.torch_device)
        batch_example_aliases_locs_end = torch.tensor(
            batch_example_aliases_locs_end, device=self.torch_device)
        batch_example_true_entities = torch.tensor(batch_example_true_entities,
                                                   device=self.torch_device)
        batch_word_indices = torch.tensor(batch_word_indices,
                                          device=self.torch_device)

        final_pred_cands = [[] for _ in range(len(text_list))]
        final_all_cands = [[] for _ in range(len(text_list))]
        final_cand_probs = [[] for _ in range(len(text_list))]
        final_pred_probs = [[] for _ in range(len(text_list))]
        final_titles = [[] for _ in range(len(text_list))]
        final_spans = [[] for _ in range(len(text_list))]
        final_aliases = [[] for _ in range(len(text_list))]
        for b_i in tqdm(
                range(0, batch_example_aliases.shape[0], ebs),
                desc="Evaluating model",
                disable=not self.verbose,
        ):
            start_span_idx = batch_example_aliases_locs_start[b_i:b_i + ebs]
            end_span_idx = batch_example_aliases_locs_end[b_i:b_i + ebs]
            word_indices = batch_word_indices[b_i:b_i + ebs]
            alias_indices = batch_example_aliases[b_i:b_i + ebs]
            x_dict = self.get_forward_batch(start_span_idx, end_span_idx,
                                            word_indices, alias_indices)
            x_dict["guid"] = torch.arange(b_i,
                                          b_i + ebs,
                                          device=self.torch_device)

            (uid_bdict, _, prob_bdict, _) = self.model(  # type: ignore
                uids=x_dict["guid"],
                X_dict=x_dict,
                Y_dict=None,
                task_to_label_dict=self.task_to_label_dict,
                return_action_outputs=False,
            )
            # ====================================================
            # EVALUATE MODEL OUTPUTS
            # ====================================================
            # recover predictions
            probs = prob_bdict[NED_TASK]
            max_probs = probs.max(2)
            max_probs_indices = probs.argmax(2)
            for ex_i in range(probs.shape[0]):
                idx_unq = batch_idx_unq[b_i + ex_i]
                entity_cands = eval_utils.map_aliases_to_candidates(
                    self.config.data_config.train_in_candidates,
                    self.config.data_config.max_aliases,
                    self.entity_db.get_alias2qids(),
                    batch_aliases_arr[b_i + ex_i],
                )
                # reshape this example's probs to (max_aliases, num_candidates)
                probs_ex = probs[ex_i].reshape(
                    self.config.data_config.max_aliases, probs.shape[2])
                for alias_idx, true_entity_pos_idx in enumerate(
                        batch_example_true_entities[b_i + ex_i]):
                    if true_entity_pos_idx != PAD_ID:
                        pred_idx = max_probs_indices[ex_i][alias_idx]
                        pred_prob = max_probs[ex_i][alias_idx].item()
                        all_cands = entity_cands[alias_idx]
                        pred_qid = all_cands[pred_idx]
                        if pred_prob > self.threshold:
                            final_all_cands[idx_unq].append(all_cands)
                            final_cand_probs[idx_unq].append(
                                probs_ex[alias_idx])
                            final_pred_cands[idx_unq].append(pred_qid)
                            final_pred_probs[idx_unq].append(pred_prob)
                            final_aliases[idx_unq].append(
                                batch_aliases_arr[b_i + ex_i][alias_idx])
                            final_spans[idx_unq].append(
                                batch_spans_arr[b_i + ex_i][alias_idx])
                            final_titles[idx_unq].append(
                                self.entity_db.get_title(pred_qid)
                                if pred_qid != "NC" else "NC")
                            total_final_exs += 1
                        else:
                            dropped_by_thresh += 1
        assert total_final_exs + dropped_by_thresh == total_start_exs, (
            f"Something went wrong and we have predicted fewer mentions than extracted. "
            f"Start {total_start_exs}, Out {total_final_exs}, No cand {dropped_by_thresh}"
        )
        res_dict = {
            "qids": final_pred_cands,
            "probs": final_pred_probs,
            "titles": final_titles,
            "cands": final_all_cands,
            "cand_probs": final_cand_probs,
            "spans": final_spans,
            "aliases": final_aliases,
        }
        return res_dict
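
    # Assumed usage (illustrative; constructor arguments are placeholders):
    #   ann = BootlegAnnotator(config=config, device=0)
    #   out = ann.label_mentions("Abraham Lincoln was born in Kentucky")
    #   for alias, qid, prob in zip(out["aliases"][0], out["qids"][0], out["probs"][0]):
    #       print(alias, qid, round(prob, 3))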

    def get_forward_batch(self, start_span_idx, end_span_idx, token_ids,
                          alias_idx):
        """Preps the forward batch for disambiguation.

        Args:
            start_span_idx: start span tensor
            end_span_idx: end span tensor
            token_ids: word token tensor
            alias_idx: alias index used for extracting candidate eids

        Returns: X_dict used in Emmental
        """
        entity_cand_eid = self.alias2cands(alias_idx).long()
        entity_cand_eid_mask = entity_cand_eid == -1
        entity_cand_eid_noneg = torch.where(
            entity_cand_eid >= 0,
            entity_cand_eid,
            (torch.ones_like(entity_cand_eid, dtype=torch.long) *
             (self.entity_db.num_entities_with_pad_and_nocand - 1)),
        )

        kg_prepped_embs = {}
        for emb_key in self.batch_on_the_fly_embs:
            kg_adj = self.batch_on_the_fly_embs[emb_key]["kg_adj"]
            prep_func = self.batch_on_the_fly_embs[emb_key][
                "kg_adj_process_func"]
            batch_prep = []
            for j in range(entity_cand_eid_noneg.shape[0]):
                batch_prep.append(
                    prep_func(entity_cand_eid_noneg[j].cpu(),
                              kg_adj).reshape(1, -1))
            kg_prepped_embs[emb_key] = torch.tensor(batch_prep,
                                                    device=self.torch_device)

        X_dict = {
            "guids": [],
            "start_span_idx": start_span_idx,
            "end_span_idx": end_span_idx,
            "token_ids": token_ids,
            "entity_cand_eid": entity_cand_eid_noneg,
            "entity_cand_eid_mask": entity_cand_eid_mask,
            "batch_on_the_fly_kg_adj": kg_prepped_embs,
        }
        return X_dict
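
    # The torch.where above re-points the -1 "no candidate" slots at the final
    # (pad) entity row so that embedding lookups stay in range, while
    # entity_cand_eid_mask remembers which slots were padding. For example,
    # with num_entities_with_pad_and_nocand == 5:
    #   entity_cand_eid      == [[3, 1, -1]]  ->  entity_cand_eid_noneg == [[3, 1, 4]]
    #   entity_cand_eid_mask == [[False, False, True]]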

    def get_char_spans(self, spans, text):
        """Helper function to get character spans instead of default word
        spans.

        Args:
            spans: word spans
            text: text

        Returns: character spans
        """
        query_toks = text.split()
        char_spans = []
        for span in spans:
            space_btwn_toks = (len(" ".join(query_toks[0:span[0] + 1])) -
                               len(" ".join(query_toks[0:span[0]])) -
                               len(query_toks[span[0]]))
            char_b = len(" ".join(query_toks[0:span[0]])) + space_btwn_toks
            char_e = char_b + len(" ".join(query_toks[span[0]:span[1]]))
            char_spans.append([char_b, char_e])
        return char_spans
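
        # Worked example: text = "Abraham Lincoln lived here", spans = [[1, 3]]
        #   query_toks = ["Abraham", "Lincoln", "lived", "here"]
        #   char_b = len("Abraham") + 1 = 8
        #   char_e = 8 + len("Lincoln lived") = 21, and text[8:21] == "Lincoln lived"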
Example #11
def main(
    conn_string,
    stg_temp_min=False,
    stg_temp_max=False,
    polarity=False,
    ce_v_max=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    rel_list = []
    if stg_temp_min:
        rel_list.append("stg_temp_min")

    if stg_temp_max:
        rel_list.append("stg_temp_max")

    if polarity:
        rel_list.append("polarity")

    if ce_v_max:
        rel_list.append("ce_v_max")

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info(f"Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(session,
                                                          dirname,
                                                          first_time=parse,
                                                          parallel=parallel,
                                                          max_docs=max_docs)
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    Part = mention_subclass("Part")
    part_matcher = get_matcher("part")
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)

    mentions.append(Part)
    ngrams.append(part_ngrams)
    matchers.append(part_matcher)

    if stg_temp_min:
        StgTempMin = mention_subclass("StgTempMin")
        stg_temp_min_matcher = get_matcher("stg_temp_min")
        stg_temp_min_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMin)
        ngrams.append(stg_temp_min_ngrams)
        matchers.append(stg_temp_min_matcher)

    if stg_temp_max:
        StgTempMax = mention_subclass("StgTempMax")
        stg_temp_max_matcher = get_matcher("stg_temp_max")
        stg_temp_max_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMax)
        ngrams.append(stg_temp_max_ngrams)
        matchers.append(stg_temp_max_matcher)

    if polarity:
        Polarity = mention_subclass("Polarity")
        polarity_matcher = get_matcher("polarity")
        polarity_ngrams = MentionNgrams(n_max=1)

        mentions.append(Polarity)
        ngrams.append(polarity_ngrams)
        matchers.append(polarity_matcher)

    if ce_v_max:
        CeVMax = mention_subclass("CeVMax")
        ce_v_max_matcher = get_matcher("ce_v_max")
        ce_v_max_ngrams = MentionNgramsVolt(n_max=1)

        mentions.append(CeVMax)
        ngrams.append(ce_v_max_ngrams)
        matchers.append(ce_v_max_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")
    logger.info(f"Total Part: {session.query(Part).count()}")
    if stg_temp_min:
        logger.info(f"Total StgTempMin: {session.query(StgTempMin).count()}")
    if stg_temp_max:
        logger.info(f"Total StgTempMax: {session.query(StgTempMax).count()}")
    if polarity:
        logger.info(f"Total Polarity: {session.query(Polarity).count()}")
    if ce_v_max:
        logger.info(f"Total CeVMax: {session.query(CeVMax).count()}")

    # Candidate Extraction
    cands = []
    throttlers = []
    if stg_temp_min:
        PartStgTempMin = candidate_subclass("PartStgTempMin",
                                            [Part, StgTempMin])
        stg_temp_min_throttler = stg_temp_filter

        cands.append(PartStgTempMin)
        throttlers.append(stg_temp_min_throttler)

    if stg_temp_max:
        PartStgTempMax = candidate_subclass("PartStgTempMax",
                                            [Part, StgTempMax])
        stg_temp_max_throttler = stg_temp_filter

        cands.append(PartStgTempMax)
        throttlers.append(stg_temp_max_throttler)

    if polarity:
        PartPolarity = candidate_subclass("PartPolarity", [Part, Polarity])
        polarity_throttler = polarity_filter

        cands.append(PartPolarity)
        throttlers.append(polarity_throttler)

    if ce_v_max:
        PartCeVMax = candidate_subclass("PartCeVMax", [Part, CeVMax])
        ce_v_max_throttler = ce_v_max_filter

        cands.append(PartCeVMax)
        throttlers.append(ce_v_max_throttler)

    candidate_extractor = CandidateExtractor(session,
                                             cands,
                                             throttlers=throttlers)

    if first_time:
        for i, docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(docs, split=i, parallelism=parallel)
            num_cands = session.query(Candidate).filter(
                Candidate.split == i).count()
            logger.info(f"Candidates in split={i}: {num_cands}")

    # These must be sorted for deterministic behavior.
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)

    end = timer()
    logger.warning(
        f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"Total train candidate: {sum(len(_) for _ in train_cands)}")
    logger.info(f"Total dev candidate: {sum(len(_) for _ in dev_cands)}")
    logger.info(f"Total test candidate: {sum(len(_) for _ in test_cands)}")

    pickle_file = os.path.join(dirname, "data/parts_by_doc_new.pkl")
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    # Check total recall
    for i, name in enumerate(rel_list):
        logger.info(name)
        result = entity_level_scores(
            candidates_to_entities(dev_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=dev_docs,
        )
        logger.info(f"{name} Total Dev Recall: {result.rec:.3f}")
        result = entity_level_scores(
            candidates_to_entities(test_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=test_docs,
        )
        logger.info(f"{name} Total Test Recall: {result.rec:.3f}")

    # Featurization
    start = timer()
    cands = []
    if stg_temp_min:
        cands.append(PartStgTempMin)

    if stg_temp_max:
        cands.append(PartStgTempMax)

    if polarity:
        cands.append(PartPolarity)

    if ce_v_max:
        cands.append(PartCeVMax)

    # Using parallelism = 1 for deterministic behavior.
    featurizer = Featurizer(session, cands, parallelism=1)
    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True)
        featurizer.apply(split=1)
        featurizer.apply(split=2)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        F_train_dict = {}
        F_dev_dict = {}
        F_test_dict = {}
        for idx, relation in enumerate(rel_list):
            F_train_dict[relation] = F_train[idx]
            F_dev_dict[relation] = F_dev[idx]
            F_test_dict[relation] = F_test[idx]

        with open(os.path.join(dirname, "F_train_dict.pkl"), "wb") as f:
            pickle.dump(F_train_dict, f)
        with open(os.path.join(dirname, "F_dev_dict.pkl"), "wb") as f:
            pickle.dump(F_dev_dict, f)
        with open(os.path.join(dirname, "F_test_dict.pkl"), "wb") as f:
            pickle.dump(F_test_dict, f)
    else:
        with open(os.path.join(dirname, "F_train_dict.pkl"), "rb") as f:
            F_train_dict = pickle.load(f)
        with open(os.path.join(dirname, "F_dev_dict.pkl"), "rb") as f:
            F_dev_dict = pickle.load(f)
        with open(os.path.join(dirname, "F_test_dict.pkl"), "rb") as f:
            F_test_dict = pickle.load(f)

        F_train = []
        F_dev = []
        F_test = []
        for relation in rel_list:
            F_train.append(F_train_dict[relation])
            F_dev.append(F_dev_dict[relation])
            F_test.append(F_test_dict[relation])

    logger.info("Done.")

    for i, cand in enumerate(cands):
        logger.info(f"{cand} Train shape: {F_train[i].shape}")
        logger.info(f"{cand} Test shape: {F_test[i].shape}")
        logger.info(f"{cand} Dev shape: {F_dev[i].shape}")

    logger.info("Labeling training data...")

    # Labeling
    start = timer()
    lfs = []
    if stg_temp_min:
        lfs.append(stg_temp_min_lfs)

    if stg_temp_max:
        lfs.append(stg_temp_max_lfs)

    if polarity:
        lfs.append(polarity_lfs)

    if ce_v_max:
        lfs.append(ce_v_max_lfs)

    # Using parallelism = 1 for deterministic behavior.
    labeler = Labeler(session, cands, parallelism=1)

    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True)
        logger.info("Done...")

        # Uncomment if debugging LFs
        #  load_transistor_labels(session, cands, ["ce_v_max"])
        #  labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        #  labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    elif re_label:
        logger.info("Updating LFs...")
        labeler.update(split=0, lfs=lfs)
        logger.info("Done...")

        # Uncomment if debugging LFs
        #  labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        #  labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    logger.info("Getting label matrices...")

    L_train = labeler.get_label_matrices(train_cands)

    # Uncomment if debugging LFs
    #  L_dev = labeler.get_label_matrices(dev_cands)
    #  L_dev_gold = labeler.get_gold_labels(dev_cands, annotator="gold")
    #
    #  L_test = labeler.get_label_matrices(test_cands)
    #  L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")

    logger.info("Done.")

    if first_time:
        marginals_dict = {}
        for idx, relation in enumerate(rel_list):
            marginals_dict[relation] = generative_model(L_train[idx])

        with open(os.path.join(dirname, "marginals_dict.pkl"), "wb") as f:
            pickle.dump(marginals_dict, f)
    else:
        with open(os.path.join(dirname, "marginals_dict.pkl"), "rb") as f:
            marginals_dict = pickle.load(f)

    marginals = []
    for relation in rel_list:
        marginals.append(marginals_dict[relation])

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    start = timer()

    word_counter = collect_word_counter(train_cands)

    # Training config
    config = {
        "meta_config": {
            "verbose": True,
            "seed": 17
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.0
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    "model/all/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }

    emmental.init(log_dir=Meta.log_path, config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)
    train_idxs = []
    train_dataloader = []
    for idx, relation in enumerate(rel_list):
        diffs = marginals[idx].max(axis=1) - marginals[idx].min(axis=1)
        train_idxs.append(np.where(diffs > 1e-6)[0])
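        # Rows whose max and min marginals differ by <= 1e-6 are effectively
        # unlabeled by the generative model, so they are excluded from training.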

        train_dataloader.append(
            EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(
                    relation,
                    train_cands[idx],
                    F_train[idx],
                    emb_layer.word2id,
                    marginals[idx],
                    train_idxs[idx],
                ),
                split="train",
                batch_size=100,
                shuffle=True,
            ))

    num_feature_keys = len(featurizer.get_keys())

    model = EmmentalModel(name=f"transistor_tasks")

    # Create one task per relation: names, arities, feature dim, and number of classes
    tasks = create_task(
        rel_list,
        [2] * len(rel_list),
        num_feature_keys,
        [2] * len(rel_list),
        emb_layer,
        model="LogisticRegression",
    )

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()

    # If given a list of dataloaders, the learner trains on all of them
    emmental_learner.learn(model, train_dataloader)

    # Build a test dataloader for each relation
    for idx, relation in enumerate(rel_list):
        test_dataloader = EmmentalDataLoader(
            task_to_label_dict={relation: "labels"},
            dataset=FonduerDataset(relation, test_cands[idx], F_test[idx],
                                   emb_layer.word2id, 2),
            split="test",
            batch_size=100,
            shuffle=False,
        )

        test_preds = model.predict(test_dataloader, return_preds=True)

        best_result, best_b = scoring(
            relation,
            test_preds,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

        # Dump CSV files for CE_V_MAX for Digi-Key analysis
        if relation == "ce_v_max":
            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=100,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            dump_candidates(test_cands[idx], Y_prob, "ce_v_max_test_probs.csv")
            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            dump_candidates(dev_cands[idx], Y_prob, "ce_v_max_dev_probs.csv")

        # Dump CSV files for POLARITY for Digi-Key analysis
        if relation == "polarity":
            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=100,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            dump_candidates(test_cands[idx], Y_prob, "polarity_test_probs.csv")
            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            dump_candidates(dev_cands[idx], Y_prob, "polarity_dev_probs.csv")

    end = timer()
    logger.warning(f"Classification Time (min): {((end - start) / 60.0):.1f}")
Example #12
def main(
    conn_string,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    gpu=None,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info(f"Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=first_time, parallel=parallel, max_docs=max_docs
    )
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")

    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    start = timer()

    Thumbnails = mention_subclass("Thumbnails")

    thumbnails_img = MentionFigures()

    class HasFigures(_Matcher):
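        """Match figure mentions whose image file exists under one of the data
        splits and whose shorter side exceeds 50 px (drops tiny images)."""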
        def _f(self, m):
            file_path = ""
            for prefix in [
                f"{dirname}/data/train/html/",
                f"{dirname}/data/dev/html/",
                f"{dirname}/data/test/html/",
            ]:
                if os.path.exists(prefix + m.figure.url):
                    file_path = prefix + m.figure.url
            if file_path == "":
                return False
            img = Image.open(file_path)
            width, height = img.size
            min_value = min(width, height)
            return min_value > 50

    mention_extractor = MentionExtractor(
        session, [Thumbnails], [thumbnails_img], [HasFigures()], parallelism=parallel
    )

    if first_time:
        mention_extractor.apply(docs)

    logger.info("Total Mentions: {}".format(session.query(Mention).count()))

    ThumbnailLabel = candidate_subclass("ThumbnailLabel", [Thumbnails])

    candidate_extractor = CandidateExtractor(
        session, [ThumbnailLabel], throttlers=[None], parallelism=parallel
    )

    if first_time:
        candidate_extractor.apply(train_docs, split=0)
        candidate_extractor.apply(dev_docs, split=1)
        candidate_extractor.apply(test_docs, split=2)

    train_cands = candidate_extractor.get_candidates(split=0)
    # Sort the dev_cands, which are used for training, for deterministic behavior
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2)

    end = timer()
    logger.warning(f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info("Total train candidate:\t{}".format(len(train_cands[0])))
    logger.info("Total dev candidate:\t{}".format(len(dev_cands[0])))
    logger.info("Total test candidate:\t{}".format(len(test_cands[0])))

    fin = open(f"{dirname}/data/ground_truth.txt", "r")
    gt = set()
    for line in fin:
        gt.add("::".join(line.lower().split()))
    fin.close()

    # Labeling
    start = timer()

    def LF_gt_label(c):
        doc_file_id = (
            f"{c[0].context.figure.document.name.lower()}.pdf::"
            f"{os.path.basename(c[0].context.figure.url.lower())}"
        )
        return TRUE if doc_file_id in gt else FALSE
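
    # Distant supervision: a candidate is labeled TRUE iff its
    # "<document>.pdf::<image file>" key appears in the ground-truth set.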

    gt_dev = [LF_gt_label(cand) for cand in dev_cands[0]]
    gt_test = [LF_gt_label(cand) for cand in test_cands[0]]

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    batch_size = 64
    input_size = 224
    K = 2

    emmental.init(log_dir=Meta.log_path, config=emmental_config)

    emmental.Meta.config["learner_config"]["task_scheduler_config"][
        "task_scheduler"
    ] = DauphinScheduler(augment_k=K, enlarge=1)

    train_dataset = ThumbnailDataset(
        "Thumbnail",
        dev_cands[0],
        gt_dev,
        "train",
        prob_label=True,
        prefix=f"{dirname}/data/dev/html/",
        input_size=input_size,
        transform_cls=Augmentation(2),
        k=K,
    )

    val_dataset = ThumbnailDataset(
        "Thumbnail",
        dev_cands[0],
        gt_dev,
        "valid",
        prob_label=False,
        prefix=f"{dirname}/data/dev/html/",
        input_size=input_size,
        k=1,
    )

    test_dataset = ThumbnailDataset(
        "Thumbnail",
        test_cands[0],
        gt_test,
        "test",
        prob_label=False,
        prefix=f"{dirname}/data/test/html/",
        input_size=input_size,
        k=1,
    )

    dataloaders = []

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=train_dataset,
            split="train",
            shuffle=True,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=val_dataset,
            split="valid",
            shuffle=False,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=test_dataset,
            split="test",
            shuffle=False,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    model = EmmentalModel(name=f"Thumbnail")
    model.add_task(
        create_task("Thumbnail", n_class=2, model="resnet18", pretrained=True)
    )

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, dataloaders)

    scores = model.score(dataloaders)

    logger.warning("Model Score:")
    logger.warning(f"precision: {scores['Thumbnail/Thumbnail/test/precision']:.3f}")
    logger.warning(f"recall: {scores['Thumbnail/Thumbnail/test/recall']:.3f}")
    logger.warning(f"f1: {scores['Thumbnail/Thumbnail/test/f1']:.3f}")
Example #13
def run_model(mode, config, run_config_path=None):
    """
    Main run method for Emmental Bootleg models.
    Args:
        mode: run mode (train, eval, dump_preds, dump_embs)
        config: parsed model config
        run_config_path: original config path (for saving)

    Returns: scores (train/eval mode), or the final result and embedding file paths (dump modes)
    """

    # Set up distributed backend and save configuration files
    setup(config, run_config_path)

    # Load entity symbols
    log_rank_0_info(logger, f"Loading entity symbols...")
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(config.data_config.entity_dir,
                              config.data_config.entity_map_dir),
        alias_cand_map_file=config.data_config.alias_cand_map,
        alias_idx_file=config.data_config.alias_idx_map,
    )
    # Create tasks
    tasks = [NED_TASK]
    if config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)

    # Create splits for data loaders
    data_splits = [TRAIN_SPLIT, DEV_SPLIT, TEST_SPLIT]
    # Slices are for eval so we only split on test/dev
    slice_splits = [DEV_SPLIT, TEST_SPLIT]
    # If doing eval, only run on test data
    if mode in ["eval", "dump_preds", "dump_embs"]:
        data_splits = [TEST_SPLIT]
        slice_splits = [TEST_SPLIT]
        # We only do dumping if use_weak_label is True
        if mode in ["dump_preds", "dump_embs"]:
            if config.data_config[
                    f"{TEST_SPLIT}_dataset"].use_weak_label is False:
                raise ValueError(
                    f"When calling dump_preds or dump_embs, we require use_weak_label to be True."
                )

    # Gets embeddings that need to be prepped during data prep or in the __get_item__ method
    batch_on_the_fly_kg_adj = get_dataloader_embeddings(config, entity_symbols)
    # Gets dataloaders
    dataloaders = get_dataloaders(
        config,
        tasks,
        data_splits,
        entity_symbols,
        batch_on_the_fly_kg_adj,
    )
    slice_datasets = get_slicedatasets(config, slice_splits, entity_symbols)

    configure_optimizer(config)

    # Create models and add tasks
    if config.model_config.attn_class == "BERTNED":
        log_rank_0_info(logger, f"Starting NED-Base Model")
        assert (config.data_config.type_prediction.use_type_pred is
                False), f"NED-Base does not support type prediction"
        assert (
            config.data_config.word_embedding.use_sent_proj is False
        ), f"NED-Base requires word_embeddings.use_sent_proj to be False"
        model = EmmentalModel(name="NED-Base")
        model.add_tasks(
            ned_task.create_task(config, entity_symbols, slice_datasets))
    else:
        log_rank_0_info(logger, f"Starting Bootleg Model")
        model = EmmentalModel(name="Bootleg")
        # TODO: make this more general for other tasks -- iterate through list of tasks
        # and add task for each
        model.add_task(
            ned_task.create_task(config, entity_symbols, slice_datasets))
        if TYPE_PRED_TASK in tasks:
            model.add_task(
                type_pred_task.create_task(config, entity_symbols,
                                           slice_datasets))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(model)

    # Print param counts
    if mode == "train":
        log_rank_0_debug(logger, "PARAMS WITH GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=True,
                                        logger=logger)
        log_rank_0_info(logger, f"===> Total Params With Grad: {total_params}")
        log_rank_0_debug(logger, "PARAMS WITHOUT GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=False,
                                        logger=logger)
        log_rank_0_info(logger,
                        f"===> Total Params Without Grad: {total_params}")

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    # Barrier
    if config["learner_config"]["local_rank"] == 0:
        torch.distributed.barrier()

    # Train model
    if mode == "train":
        emmental_learner = EmmentalLearner()
        emmental_learner._set_optimizer(model)
        emmental_learner.learn(model, dataloaders)
        if config.learner_config.local_rank in [0, -1]:
            model.save(f"{emmental.Meta.log_path}/last_model.pth")

    # Multi-gpu DataParallel eval (NOT distributed)
    if mode in ["eval", "dump_embs", "dump_preds"]:
        # This happens inside EmmentalLearner for training
        if (config["learner_config"]["local_rank"] == -1
                and config["model_config"]["dataparallel"]):
            model._to_dataparallel()

    # If just finished training a model or in eval mode, run eval
    if mode in ["train", "eval"]:
        scores = model.score(dataloaders)
        # Save metrics and models
        log_rank_0_info(logger, f"Saving metrics to {emmental.Meta.log_path}")
        log_rank_0_info(logger, f"Metrics: {scores}")
        scores["log_path"] = emmental.Meta.log_path
        if config.learner_config.local_rank in [0, -1]:
            write_to_file(f"{emmental.Meta.log_path}/{mode}_metrics.txt",
                          scores)
            eval_utils.write_disambig_metrics_to_csv(
                f"{emmental.Meta.log_path}/{mode}_disambig_metrics.csv",
                scores)
        return scores

    # If you want detailed dumps, save model outputs
    assert mode in [
        "dump_preds",
        "dump_embs",
    ], 'Mode must be "dump_preds" or "dump_embs"'
    dump_embs = False if mode != "dump_embs" else True
    assert (
        len(dataloaders) == 1
    ), f"We should only have length 1 dataloaders for dump_embs and dump_preds!"
    final_result_file, final_out_emb_file = None, None
    if config.learner_config.local_rank in [0, -1]:
        # Setup files/folders
        filename = os.path.basename(dataloaders[0].dataset.raw_filename)
        log_rank_0_debug(
            logger,
            f"Collecting sentence to mention map {os.path.join(config.data_config.data_dir, filename)}",
        )
        sentidx2num_mentions, sent_idx2row = eval_utils.get_sent_idx2num_mens(
            os.path.join(config.data_config.data_dir, filename))
        log_rank_0_debug(logger, f"Done collecting sentence to mention map")
        eval_folder = eval_utils.get_eval_folder(filename)
        subeval_folder = os.path.join(eval_folder, "batch_results")
        utils.ensure_dir(subeval_folder)
        # Keep track of sentences already dumped; these will only be the ones with mentions
        all_dumped_sentences = set()
        number_dumped_batches = 0
        total_mentions_seen = 0
        all_result_files = []
        all_out_emb_files = []
        # Iterating over batches of predictions
        for res_i, res_dict in enumerate(
                eval_utils.batched_pred_iter(
                    model,
                    dataloaders[0],
                    config.run_config.eval_accumulation_steps,
                    sentidx2num_mentions,
                )):
            (
                result_file,
                out_emb_file,
                final_sent_idxs,
                mentions_seen,
            ) = eval_utils.disambig_dump_preds(
                res_i,
                total_mentions_seen,
                config,
                res_dict,
                sentidx2num_mentions,
                sent_idx2row,
                subeval_folder,
                entity_symbols,
                dump_embs,
                NED_TASK,
            )
            all_dumped_sentences.update(final_sent_idxs)
            all_result_files.append(result_file)
            all_out_emb_files.append(out_emb_file)
            total_mentions_seen += mentions_seen
            number_dumped_batches += 1

        # Dump the sentences that had no mentions and were not already dumped
        # Assert all remaining sentences have no mentions
        assert all(
            v == 0 for k, v in sentidx2num_mentions.items()
            if k not in all_dumped_sentences
        ), (f"Sentences with mentions were not dumped: "
            f"{[k for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences]}"
            )
        empty_sentidx2row = {
            k: v
            for k, v in sent_idx2row.items() if k not in all_dumped_sentences
        }
        empty_resultfile = eval_utils.get_result_file(number_dumped_batches,
                                                      subeval_folder)
        all_result_files.append(empty_resultfile)
        # Dump the outputs
        eval_utils.write_data_labels_single(
            sentidx2row=empty_sentidx2row,
            output_file=empty_resultfile,
            filt_emb_data=None,
            sental2embid={},
            alias_cand_map=entity_symbols.get_alias2qids(),
            qid2eid=entity_symbols.get_qid2eid(),
            result_alias_offset=total_mentions_seen,
            train_in_cands=config.data_config.train_in_candidates,
            max_cands=entity_symbols.max_candidates,
            dump_embs=dump_embs,
        )

        log_rank_0_info(
            logger,
            f"Finished dumping. Merging results across accumulation steps.")
        # Final result files for labels and embeddings
        final_result_file = os.path.join(eval_folder,
                                         config.run_config.result_label_file)
        # Copy labels
        with open(final_result_file, "wb") as output:
            for file in all_result_files:
                with open(file, "rb") as src:
                    shutil.copyfileobj(src, output)
        log_rank_0_info(logger, f"Bootleg labels saved at {final_result_file}")
        # Try to copy embeddings
        if dump_embs:
            final_out_emb_file = os.path.join(
                eval_folder, config.run_config.result_emb_file)
            log_rank_0_info(
                logger,
                f"Trying to merge numpy embedding arrays. "
                f"If your machine is limited in memory, this may cause OOM errors. "
                f"Is that happens, result files should be saved in {subeval_folder}.",
            )
            all_arrays = []
            for npfile in all_out_emb_files:
                all_arrays.append(np.load(npfile))
            np.save(final_out_emb_file, np.concatenate(all_arrays))
            log_rank_0_info(
                logger, f"Bootleg embeddings saved at {final_out_emb_file}")

        # Cleanup
        try_rmtree(subeval_folder)
    return final_result_file, final_out_emb_file
Example #14
        train_marginals,
        train_idxs,
    ),
    split="train",
    batch_size=100,
    shuffle=True,
)

tasks = create_task(
    ATTRIBUTE, 2, F_train[0].shape[1], 2, emb_layer, model="LogisticRegression"
)

emmental_model = EmmentalModel()

for task in tasks:
    emmental_model.add_task(task)

emmental_learner = EmmentalLearner()
emmental_learner.learn(emmental_model, [train_dataloader])

from my_fonduer_model import MyFonduerModel
model = MyFonduerModel()
code_paths = [
    "fonduer_subclasses.py",
    "fonduer_model.py",  # will no longer needed once HazyResearch/fonduer#407 is merged.
    "my_fonduer_model.py",
]

import fonduer_model

# Discriminative model