Ejemplo n.º 1
0
def main(
    conn_string,
    gain=False,
    current=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    parallel=8,
    log_dir="logs",
    verbose=False,
):
    """Run parsing, extraction, featurization, training and scoring.

    Up to two relations are handled, "gain" and "current"; each is processed
    only when its flag is set. Feature matrices and marginals are pickled
    next to this file when ``first_time`` is true and re-loaded from those
    pickles otherwise.

    Args:
        conn_string: Database connection string passed to ``Meta.init``.
        gain: Process the "gain" relation.
        current: Process the "current" (supply current) relation.
        max_docs: Maximum number of documents, forwarded to
            ``parse_dataset``.
        parse: Forwarded to ``parse_dataset`` as its ``first_time`` argument.
        first_time: Run mention/candidate extraction and featurization,
            build the marginals and write the pickle caches. When false the
            caches are read instead.
        re_label: Currently unused; only the disabled labeling code below
            refers to it.
        parallel: Parallelism for parsing and mention/candidate extraction.
        log_dir: Log directory relative to this file; falsy means "logs".
        verbose: Log at INFO level instead of WARNING.
    """
    # Setup initial configuration
    if not log_dir:
        log_dir = "logs"

    level = logging.INFO if verbose else logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    def _save_pickle(obj, name):
        """Pickle *obj* to ``dirname/name`` and close the file."""
        with open(os.path.join(dirname, name), "wb") as fout:
            pickle.dump(obj, fout)

    def _load_pickle(name):
        """Load a cache file previously written by ``_save_pickle``."""
        # NOTE: pickle must only be used on files this script wrote itself.
        with open(os.path.join(dirname, name), "rb") as fin:
            return pickle.load(fin)

    # The position of a relation in rel_list is also its index into the
    # per-class candidate lists and feature matrices built below.
    rel_list = []
    if gain:
        rel_list.append("gain")

    if current:
        rel_list.append("current")

    logger.info("=" * 30)
    logger.info(f"Running with parallel: {parallel}, max_docs: {max_docs}")

    session = Meta.init(conn_string).Session()

    # Parsing
    start = timer()
    logger.info("Starting parsing...")
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session,
        dirname,
        first_time=parse,
        parallel=parallel,
        max_docs=max_docs,
    )
    logger.debug("Done")
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of Documents: {len(docs)}")
    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    if gain:
        Gain = mention_subclass("Gain")
        mentions.append(Gain)
        ngrams.append(MentionNgrams(n_max=2))
        matchers.append(get_gain_matcher())

    if current:
        Current = mention_subclass("SupplyCurrent")
        mentions.append(Current)
        ngrams.append(MentionNgramsCurrent(n_max=3))
        matchers.append(get_supply_current_matcher())

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")

    if gain:
        logger.info(f"Total Gain: {session.query(Gain).count()}")

    if current:
        logger.info(f"Total Current: {session.query(Current).count()}")

    cand_classes = []
    if gain:
        GainCand = candidate_subclass("GainCand", [Gain])
        cand_classes.append(GainCand)
    if current:
        CurrentCand = candidate_subclass("CurrentCand", [Current])
        cand_classes.append(CurrentCand)

    candidate_extractor = CandidateExtractor(session, cand_classes)

    if first_time:
        # Use a dedicated loop name so the parsed corpus in `docs` survives.
        for split, split_docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(split_docs,
                                      split=split,
                                      parallelism=parallel)

    # These must be sorted for deterministic behavior.
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)

    # Sum over however many candidate classes are enabled; indexing [0] and
    # [1] directly fails when only one relation is selected.
    for split_name, split_cands in (
        ("train", train_cands),
        ("dev", dev_cands),
        ("test", test_cands),
    ):
        total = sum(len(class_cands) for class_cands in split_cands)
        logger.info(f"Total {split_name} candidate: {total}")

    logger.info("Done w/ candidate extraction.")
    end = timer()
    logger.warning(f"CE Time (min): {((end - start) / 60.0):.1f}")

    # First, check total recall
    #  result = entity_level_scores(
    #      candidates_to_entities(dev_cands[0], is_gain=True),
    #      corpus=dev_docs,
    #      is_gain=True,
    #  )
    #  logger.info(f"Gain Total Dev Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #  result = entity_level_scores(
    #      candidates_to_entities(test_cands[0], is_gain=True),
    #      corpus=test_docs,
    #      is_gain=True,
    #  )
    #  logger.info(f"Gain Total Test Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #
    #  result = entity_level_scores(
    #      candidates_to_entities(dev_cands[1], is_gain=False),
    #      corpus=dev_docs,
    #      is_gain=False,
    #  )
    #  logger.info(f"Current Total Dev Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #  result = entity_level_scores(
    #      candidates_to_entities(test_cands[1], is_gain=False),
    #      corpus=test_docs,
    #      is_gain=False,
    #  )
    #  logger.info(f"Current Test Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")

    start = timer()

    # Using parallelism = 1 for deterministic behavior.
    featurizer = Featurizer(session, cand_classes, parallelism=1)

    if first_time:
        logger.info("Starting featurizer...")
        # Set feature space based on dev set, which we use for training rather
        # than the large train set.
        featurizer.apply(split=1, train=True)
        featurizer.apply(split=0)
        featurizer.apply(split=2)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    # Serialize feature matrices on first run
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        # Key the matrices by relation name so a later run with a different
        # gain/current selection can still pick out what it needs.
        F_train_dict = dict(zip(rel_list, F_train))
        F_dev_dict = dict(zip(rel_list, F_dev))
        F_test_dict = dict(zip(rel_list, F_test))

        _save_pickle(F_train_dict, "F_train_dict.pkl")
        _save_pickle(F_dev_dict, "F_dev_dict.pkl")
        _save_pickle(F_test_dict, "F_test_dict.pkl")
    else:
        F_train_dict = _load_pickle("F_train_dict.pkl")
        F_dev_dict = _load_pickle("F_dev_dict.pkl")
        F_test_dict = _load_pickle("F_test_dict.pkl")

        F_train = [F_train_dict[relation] for relation in rel_list]
        F_dev = [F_dev_dict[relation] for relation in rel_list]
        F_test = [F_test_dict[relation] for relation in rel_list]

    logger.info("Done.")

    start = timer()
    logger.info("Labeling training data...")
    #  labeler = Labeler(session, cand_classes)
    #  lfs = []
    #  if gain:
    #      lfs.append(gain_lfs)
    #
    #  if current:
    #      lfs.append(current_lfs)
    #
    #  if first_time:
    #      logger.info("Applying LFs...")
    #      labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
    #  elif re_label:
    #      logger.info("Re-applying LFs...")
    #      labeler.update(split=0, lfs=lfs, parallelism=parallel)
    #
    #  logger.info("Done...")

    #  logger.info("Getting label matrices...")
    #  L_train = labeler.get_label_matrices(train_cands)
    #  logger.info("Done...")

    if first_time:
        marginals_dict = {}
        for idx, relation in enumerate(rel_list):
            is_gain = relation == "gain"
            # Manually create marginals from human annotations
            dev_gold_entities = get_gold_set(is_gain=is_gain)
            marginal = []
            for cand in dev_cands[idx]:
                in_gold = any(
                    entity in dev_gold_entities
                    for entity in cand_to_entity(cand, is_gain=is_gain))
                # Column 1 carries the mass for candidates found in the gold
                # set, column 0 for all others.
                marginal.append([0.0, 1.0] if in_gold else [1.0, 0.0])

            marginals_dict[relation] = np.array(marginal)

        _save_pickle(marginals_dict, "marginals_dict.pkl")
    else:
        marginals_dict = _load_pickle("marginals_dict.pkl")

    marginals = [marginals_dict[relation] for relation in rel_list]

    end = timer()
    logger.warning(
        f"Weak Supervision Time (min): {((end - start) / 60.0):.1f}")

    start = timer()

    word_counter = collect_word_counter(train_cands)

    # Training config
    config = {
        "meta_config": {
            "verbose": True,
            "seed": 30
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 500,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.005
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    "model/all/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }

    emmental.init(log_dir=Meta.log_path, config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)

    def _eval_loader(relation, cands, features, split):
        """Build a non-shuffling dataloader over *cands* for prediction."""
        return EmmentalDataLoader(
            task_to_label_dict={relation: "labels"},
            dataset=FonduerDataset(relation, cands, features,
                                   emb_layer.word2id, 2),
            split=split,
            batch_size=256,
            shuffle=False,
        )

    train_idxs = []
    train_dataloaders = []
    for idx, relation in enumerate(rel_list):
        # Keep only rows whose marginals are not (near-)uniform.
        diffs = marginals[idx].max(axis=1) - marginals[idx].min(axis=1)
        train_idxs.append(np.where(diffs > 1e-6)[0])

        # only uses dev set as training data, with human annotations
        train_dataloaders.append(
            EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(
                    relation,
                    dev_cands[idx],
                    F_dev[idx],
                    emb_layer.word2id,
                    marginals[idx],
                    train_idxs[idx],
                ),
                split="train",
                batch_size=256,
                shuffle=True,
            ))

    num_feature_keys = len(featurizer.get_keys())

    model = EmmentalModel(name="opamp_tasks")

    # List relation names, arities, list of classes
    tasks = create_task(
        rel_list,
        [2] * len(rel_list),
        num_feature_keys,
        [2] * len(rel_list),
        emb_layer,
        model="LogisticRegression",
    )

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()

    # If given a list of multi, will train on multiple
    emmental_learner.learn(model, train_dataloaders)

    # Score and dump each relation; both relations go through the same steps
    # and differ only in `is_gain` and the output file prefix.
    for idx, relation in enumerate(rel_list):
        is_gain = relation == "gain"

        test_preds = model.predict(
            _eval_loader(relation, test_cands[idx], F_test[idx], "test"),
            return_preds=True,
        )

        # The returned best result / threshold are not used here.
        scoring(
            test_preds,
            test_cands[idx],
            test_docs,
            is_gain=is_gain,
            num=100,
        )

        # Dump CSV files for analysis
        train_preds = model.predict(
            _eval_loader(relation, train_cands[idx], F_train[idx], "train"),
            return_preds=True,
        )
        Y_prob = np.array(train_preds["probs"][relation])[:, TRUE]
        # The first write is made without append=True; the test and dev
        # writes below pass append=True.
        output_csv(train_cands[idx], Y_prob, is_gain=is_gain)

        Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
        output_csv(test_cands[idx], Y_prob, is_gain=is_gain, append=True)
        dump_candidates(test_cands[idx],
                        Y_prob,
                        f"{relation}_test_probs.csv",
                        is_gain=is_gain)

        dev_preds = model.predict(
            _eval_loader(relation, dev_cands[idx], F_dev[idx], "dev"),
            return_preds=True,
        )

        Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
        output_csv(dev_cands[idx], Y_prob, is_gain=is_gain, append=True)
        dump_candidates(dev_cands[idx],
                        Y_prob,
                        f"{relation}_dev_probs.csv",
                        is_gain=is_gain)

    end = timer()
    logger.warning(
        f"Classification AND dump data Time (min): {((end - start) / 60.0):.1f}"
    )
Ejemplo n.º 2
0
)

from fonduer.candidates.models import Mention

# Extract mentions from the training documents and report how many rows the
# Mention table now holds.
mention_extractor.apply(train_docs, parallelism=PARALLEL)
# num_names = session.query(Presidentname).count()
# num_pobs = session.query(Placeofbirth).count()
mention_total = session.query(Mention).count()
print(f"Total Mentions: {mention_total}")

from fonduer.candidates import CandidateExtractor

# Build candidates for split 0 and fetch them back, one list per class.
candidate_extractor = CandidateExtractor(session, candidate_classes)
candidate_extractor.apply(train_docs, split=0, parallelism=PARALLEL)
train_cands = candidate_extractor.get_candidates(split=0)
first_class_total = len(train_cands[0])
print(f"Number of Candidates: {first_class_total}")

from fonduer.features import Featurizer
import pickle

# Featurize split 0 with train=True and collect the feature matrices for the
# candidates fetched above.
featurizer = Featurizer(session, candidate_classes)
featurizer.apply(split=0, train=True, parallelism=PARALLEL)
F_train = featurizer.get_feature_matrices(train_cands)

from wiki_table_utils import load_president_gold_labels

# Location of the gold-label CSV.
gold_file = "data/president_tutorial_gold.csv"
Ejemplo n.º 3
0
def test_cand_gen_cascading_delete():
    """Check how deletions propagate between candidates and mentions.

    Parses one document, extracts Part/Temp mentions and PartTemp candidates,
    then verifies that deleting via the parent ``Candidate`` class removes
    the subclass row, that deleting a candidate leaves its mentions alone,
    and that clearing mentions removes every candidate.
    """
    on_macos = platform == "darwin"
    if on_macos:
        logger.info("Using single core.")
    else:
        logger.info("Using two cores.")
    # Travis only gives 2 cores
    n_procs = 1 if on_macos else 2

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    def count(model):
        """Return the number of rows currently stored for *model*."""
        return session.query(model).count()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor, parallelism=n_procs)
    assert count(Document) == max_docs
    assert count(Sentence) == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp],
        [part_ngrams, temp_ngrams],
        [part_matcher, temp_matcher],
    )
    mention_extractor.clear_all()
    mention_extractor.apply(docs, parallelism=n_procs)

    assert count(Mention) == 93
    assert count(Part) == 70
    assert count(Temp) == 23
    part = session.query(Part).order_by(Part.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session,
        [PartTemp],
        throttlers=[temp_throttler],
    )
    candidate_extractor.apply(docs, split=0, parallelism=n_procs)

    assert count(PartTemp) == 1432
    assert count(Candidate) == 1432
    first_doc = docs[0]
    assert first_doc.name == "112823"
    assert len(first_doc.parts) == 70
    assert len(first_doc.temps) == 23

    # Delete from parent class should cascade to child
    victim = session.query(Candidate).first()
    parent_query = session.query(Candidate).filter_by(id=victim.id)
    parent_query.delete(synchronize_session="fetch")
    assert count(Candidate) == 1431
    assert count(PartTemp) == 1431

    # Test that deletion of a Candidate does not delete the Mention
    victim = session.query(PartTemp).first()
    child_query = session.query(PartTemp).filter_by(id=victim.id)
    child_query.delete(synchronize_session="fetch")
    assert count(PartTemp) == 1430
    assert count(Temp) == 23
    assert count(Part) == 70

    # Clearing Mentions should also delete Candidates
    mention_extractor.clear()
    assert count(Mention) == 0
    assert count(Part) == 0
    assert count(Temp) == 0
    assert count(PartTemp) == 0
    assert count(Candidate) == 0
Ejemplo n.º 4
0
def test_unary_relation_feature_extraction():
    """Check that the per-library feature key counts add up to the default.

    Builds unary ``PartRel`` candidates from one parsed document, featurizes
    them once with the default feature library and once per individual
    library (textual, tabular, structural, visual), and asserts that the
    default key count equals the sum of the four individual counts.
    """
    n_procs = 1

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor, parallelism=n_procs)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    Part = mention_subclass("Part")
    mention_extractor = MentionExtractor(
        session,
        [Part],
        [MentionNgrams(n_max=1)],
        [part_matcher],
    )
    mention_extractor.apply(docs, parallelism=n_procs)

    assert docs[0].name == "112823"
    assert session.query(Part).count() == 58
    part = session.query(Part).order_by(Part.id).all()[0]
    logger.info(f"Part: {part.context}")

    # Candidate Extraction
    PartRel = candidate_subclass("PartRel", [Part])
    candidate_extractor = CandidateExtractor(session, [PartRel])
    candidate_extractor.apply(docs, split=0, parallelism=n_procs)

    def count_feature_keys(**featurizer_kwargs):
        """Featurize split 0, return the FeatureKey count, then clear."""
        featurizer = Featurizer(session, [PartRel], **featurizer_kwargs)
        featurizer.apply(split=0, train=True, parallelism=n_procs)
        n_keys = session.query(FeatureKey).count()
        # Clear the training keys so the next run starts from a clean slate.
        featurizer.clear(train=True)
        return n_keys

    # Featurization based on default feature library
    n_default_feats = count_feature_keys()

    # Featurization restricted to one feature library at a time
    per_library_counts = []
    for library in ("textual", "tabular", "structural", "visual"):
        extractor = FeatureExtractor(features=[library])
        per_library_counts.append(
            count_feature_keys(feature_extractors=extractor))

    assert n_default_feats == sum(per_library_counts)
Ejemplo n.º 5
0
def main(
    conn_string,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    gpu=None,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Train and score an image classifier over figure candidates.

    Parses the dataset, extracts one ``Thumbnails`` mention per sufficiently
    large figure, wraps them in ``ThumbnailLabel`` candidates, labels dev and
    test candidates from ``data/ground_truth.txt``, runs a random
    hyper-parameter search over a resnet18-based ``EndModel`` trained on the
    dev candidates, and logs precision/recall/F1 on the test candidates.

    Args:
        conn_string: Database connection string passed to ``Meta.init``.
        max_docs: Maximum number of documents, forwarded to
            ``parse_dataset``.
        parse: Currently unused by this function.
        first_time: Forwarded to ``parse_dataset`` and also gates mention and
            candidate extraction.
        gpu: CUDA device index (int or numeric string). ``None`` or an empty
            string keeps the model on the CPU.
        parallel: Parallelism for parsing and extraction.
        log_dir: Log directory relative to this file; falsy means "logs".
        verbose: Log at INFO level instead of WARNING.
    """
    if not log_dir:
        log_dir = "logs"

    level = logging.INFO if verbose else logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    tuner_config = {"max_search": 3}

    em_config = {
        # GENERAL
        "seed": None,
        "verbose": True,
        "show_plots": True,
        # Network
        # The first value is the output dim of the input module (or the sum of
        # the output dims of all the input modules if multitask=True and
        # multiple input modules are provided). The last value is the
        # output dim of the head layer (i.e., the cardinality of the
        # classification task). The remaining values are the output dims of
        # middle layers (if any). The number of middle layers will be inferred
        # from this list.
        #     "layer_out_dims": [10, 2],
        # Input layer configs
        "input_layer_config": {
            "input_relu": False,
            "input_batchnorm": False,
            "input_dropout": 0.0,
        },
        # Middle layer configs
        "middle_layer_config": {
            "middle_relu": False,
            "middle_batchnorm": False,
            "middle_dropout": 0.0,
        },
        # Can optionally skip the head layer completely, for e.g. running baseline
        # models...
        "skip_head": True,
        # GPU
        "device": "cpu",
        # MODEL CLASS
        # A bare "resnet18" string used to sit here; Python concatenated it
        # with the next key, producing "resnet18src" instead of "src". The
        # architecture itself is selected by get_cnn("resnet18") below.
        # DATA CONFIG
        "src": "gm",
        # TRAINING
        "train_config": {
            # Display
            "print_every": 1,  # Print after this many epochs
            "disable_prog_bar": False,  # Disable progress bar each epoch
            # Dataloader
            "data_loader_config": {
                "batch_size": 32,
                "num_workers": 8,
                "sampler": None
            },
            # Loss weights
            "loss_weights": [0.5, 0.5],
            # Train Loop
            "n_epochs": 20,
            # 'grad_clip': 0.0,
            "l2": 0.0,
            # "lr": 0.01,
            "validation_metric": "accuracy",
            "validation_freq": 1,
            # Evaluate dev for during training every this many epochs
            # Optimizer
            "optimizer_config": {
                "optimizer": "adam",
                "optimizer_common": {
                    "lr": 0.01
                },
                # Optimizer - SGD
                "sgd_config": {
                    "momentum": 0.9
                },
                # Optimizer - Adam
                "adam_config": {
                    "betas": (0.9, 0.999)
                },
            },
            # Scheduler
            "scheduler_config": {
                "scheduler": "reduce_on_plateau",
                # ['constant', 'exponential', 'reduce_on_plateu']
                # Freeze learning rate initially this many epochs
                "lr_freeze": 0,
                # Scheduler - exponential
                "exponential_config": {
                    "gamma": 0.9
                },  # decay rate
                # Scheduler - reduce_on_plateau
                "plateau_config": {
                    "factor": 0.5,
                    "patience": 1,
                    "threshold": 0.0001,
                    "min_lr": 1e-5,
                },
            },
            # Checkpointer
            "checkpoint": True,
            "checkpoint_config": {
                "checkpoint_min": -1,
                # The initial best score to beat to merit checkpointing
                "checkpoint_runway": 0,
                # Don't start taking checkpoints until after this many epochs
            },
        },
    }

    session = Meta.init(conn_string).Session()

    # All relative "data/..." paths below are resolved against this file's
    # directory, so switch into it and use "." from here on.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    logger.info(f"CWD: {os.getcwd()}")
    dirname = "."

    # NOTE(review): `parse` is not forwarded here; `first_time` controls
    # parsing as well as extraction. Confirm whether that is intended.
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session,
        dirname,
        first_time=first_time,
        parallel=parallel,
        max_docs=max_docs)
    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")

    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    Thumbnails = mention_subclass("Thumbnails")

    thumbnails_img = MentionFigures()

    class HasFigures(_Matcher):
        """Match figures whose image file exists and is larger than 50px."""

        def _f(self, m):
            file_path = ""
            # No early exit: if several prefixes contain the file, the last
            # matching prefix wins.
            for prefix in [
                    "data/train/html/", "data/dev/html/", "data/test/html/"
            ]:
                if os.path.exists(prefix + m.figure.url):
                    file_path = prefix + m.figure.url
            if file_path == "":
                return False
            # Close the image file as soon as its size has been read.
            with Image.open(file_path) as img:
                width, height = img.size
            return min(width, height) > 50

    mention_extractor = MentionExtractor(session, [Thumbnails],
                                         [thumbnails_img], [HasFigures()],
                                         parallelism=parallel)

    if first_time:
        mention_extractor.apply(docs)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")

    ThumbnailLabel = candidate_subclass("ThumbnailLabel", [Thumbnails])

    candidate_extractor = CandidateExtractor(session, [ThumbnailLabel],
                                             throttlers=[None],
                                             parallelism=parallel)

    if first_time:
        candidate_extractor.apply(train_docs, split=0)
        candidate_extractor.apply(dev_docs, split=1)
        candidate_extractor.apply(test_docs, split=2)

    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)

    logger.info(f"Total train candidate:\t{len(train_cands[0])}")
    logger.info(f"Total dev candidate:\t{len(dev_cands[0])}")
    logger.info(f"Total test candidate:\t{len(test_cands[0])}")

    # Each ground-truth line is lower-cased and its whitespace-separated
    # fields are joined with "::" to form the lookup key.
    with open("data/ground_truth.txt", "r") as fin:
        gt = {"::".join(line.lower().split()) for line in fin}

    def LF_gt_label(c):
        """Return TRUE when the candidate's figure is in the ground truth."""
        doc_file_id = (f"{c[0].context.figure.document.name.lower()}.pdf::"
                       f"{os.path.basename(c[0].context.figure.url.lower())}")
        return TRUE if doc_file_id in gt else FALSE

    # Dev labels are kept in two forms: one-hot rows (used as training
    # targets) and scalar labels 1.0 / 2.0 (used by the dev loader).
    gt_dev_pb = []
    gt_dev = []
    for cand in dev_cands[0]:
        if LF_gt_label(cand) == 1:
            gt_dev_pb.append([1.0, 0.0])
            gt_dev.append(1.0)
        else:
            gt_dev_pb.append([0.0, 1.0])
            gt_dev.append(2.0)

    gt_test = [LF_gt_label(cand) for cand in test_cands[0]]

    batch_size = 64
    input_size = 224

    # The dev candidates double as training data, with the one-hot labels.
    train_loader = torch.utils.data.DataLoader(
        ImageList(
            data=dev_cands[0],
            label=torch.Tensor(gt_dev_pb),
            transform=transform(input_size),
            prefix="data/dev/html/",
        ),
        batch_size=batch_size,
        shuffle=False,
    )

    dev_loader = torch.utils.data.DataLoader(
        ImageList(
            data=dev_cands[0],
            label=gt_dev,
            transform=transform(input_size),
            prefix="data/dev/html/",
        ),
        batch_size=batch_size,
        shuffle=False,
    )

    test_loader = torch.utils.data.DataLoader(
        ImageList(
            data=test_cands[0],
            label=gt_test,
            transform=transform(input_size),
            prefix="data/test/html/",
        ),
        batch_size=100,
        shuffle=False,
    )

    search_space = {
        "l2": [0.001, 0.0001, 0.00001],  # linear range
        "lr": {
            "range": [0.0001, 0.1],
            "scale": "log"
        },  # log range
    }

    train_config = em_config["train_config"]

    # Defining network parameters
    num_classes = 2
    #  fc_size = 2
    #  hidden_size = 2
    pretrained = True

    # Set CUDA device. Compare against None/"" explicitly so that device
    # index 0 is not mistaken for "no GPU requested".
    if gpu is not None and gpu != "":
        em_config["device"] = "cuda"
        torch.cuda.set_device(int(gpu))

    # Initializing input module
    input_module = get_cnn("resnet18",
                           pretrained=pretrained,
                           num_classes=num_classes)

    # Initializing model object
    init_args = [[num_classes]]
    init_kwargs = {"input_module": input_module}
    init_kwargs.update(em_config)

    # Searching model
    log_config = {
        "log_dir": os.path.join(dirname, log_dir),
        "run_name": "image"
    }
    searcher = RandomSearchTuner(EndModel, **log_config)

    end_model = searcher.search(
        search_space,
        dev_loader,
        train_args=[train_loader],
        init_args=init_args,
        init_kwargs=init_kwargs,
        train_kwargs=train_config,
        max_search=tuner_config["max_search"],
    )

    # Evaluating model
    scores = end_model.score(
        test_loader,
        metric=["accuracy", "precision", "recall", "f1"],
        break_ties="abstain",
    )
    # `scores` follows the order of `metric`; index 0 (accuracy) is not
    # logged.
    logger.warning("End Model Score:")
    logger.warning(f"precision: {scores[1]:.3f}")
    logger.warning(f"recall: {scores[2]:.3f}")
    logger.warning(f"f1: {scores[3]:.3f}")
Ejemplo n.º 6
0
def test_incremental(caplog):
    """Run an end-to-end test on incremental additions.

    A first document is parsed and taken through mention extraction,
    candidate extraction, featurization and labeling. A second document is
    then parsed with ``clear=False`` and every stage is re-run on the new
    document only. The assertions pin the row counts after each step.
    """
    caplog.set_level(logging.INFO)

    PARALLEL = 1

    max_docs = 1

    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/dtc114w.html"
    pdf_path = "tests/data/pdf/dtc114w.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    # Compare the whole corpus with the documents of the most recent parse.
    # Fetching both sides with get_documents() would compare the corpus with
    # itself and could never fail.
    docs = corpus_parser.get_documents()
    last_docs = corpus_parser.get_last_documents()

    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(
        session, [Part, Temp], [part_ngrams, temp_ngrams], [part_matcher, temp_matcher]
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 9

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp], throttlers=[temp_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 78
    assert session.query(Candidate).count() == 78

    # Grab candidate lists (one list per candidate class)
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 78

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 78
    assert session.query(FeatureKey).count() == 496

    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (78, 496)
    assert len(featurizer.get_keys()) == 496

    # Test Dropping FeatureKey
    featurizer.drop_keys(["CORE_e1_LENGTH_1"])
    assert session.query(FeatureKey).count() == 495

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    # lfs is a list of LF lists, one inner list per candidate class.
    labeler.apply(split=0, lfs=[stg_temp_lfs], train=True, parallelism=PARALLEL)
    assert session.query(Label).count() == 78

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (78, 5)
    assert len(labeler.get_keys()) == 5

    # Incremental step: add a second document without clearing the first.
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser.apply(doc_preprocessor, pdf_path=pdf_path, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    new_docs = corpus_parser.get_last_documents()

    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # Get mentions from just the new docs
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)

    # Totals now cover both documents.
    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 33

    # Just run candidate extraction and assign to split 0
    candidate_extractor.apply(new_docs, split=0, parallelism=PARALLEL, clear=False)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1574

    # Update features
    featurizer.update(new_docs, parallelism=PARALLEL)
    assert session.query(Feature).count() == 1574
    assert session.query(FeatureKey).count() == 2425
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (1574, 2425)
    assert len(featurizer.get_keys()) == 2425

    # Update Labels; the sixth key appears once the second document is added
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 1574
    assert session.query(LabelKey).count() == 6
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (1574, 6)

    # Test clear
    featurizer.clear(train=True)
    assert session.query(FeatureKey).count() == 0
Ejemplo n.º 7
0
def test_cand_gen_cascading_delete(database_session):
    """Check how deletions propagate between candidates and mentions."""
    db = database_session

    def count(model):
        """Return the number of rows currently stored for ``model``."""
        return db.query(model).count()

    # Two workers, since GitHub Actions hosted runners provide 2 cores:
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    workers = 2
    doc_limit = 1

    html_dir = "tests/data/html/"
    pdf_dir = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor(html_dir, max_docs=doc_limit)
    visual_parser = PdfVisualParser(pdf_dir)
    parser = Parser(
        db, structural=True, lingual=True, visual_parser=visual_parser
    )
    parser.apply(preprocessor, parallelism=workers)
    assert count(Document) == doc_limit
    assert count(Sentence) == 799
    documents = db.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_space = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_space = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_classes = [Part, Temp]
    mention_spaces = [part_space, temp_space]
    mention_matchers = [part_matcher, temp_matcher]
    extractor = MentionExtractor(
        db, mention_classes, mention_spaces, mention_matchers
    )
    extractor.clear_all()
    extractor.apply(documents, parallelism=workers)

    assert count(Mention) == 93
    assert count(Part) == 70
    assert count(Temp) == 23
    parts_by_id = db.query(Part).order_by(Part.id).all()
    first_part = parts_by_id[0]
    temps_by_id = db.query(Temp).order_by(Temp.id).all()
    first_temp = temps_by_id[0]
    logger.info(f"Part: {first_part.context}")
    logger.info(f"Temp: {first_temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    cand_extractor = CandidateExtractor(
        db, [PartTemp], throttlers=[temp_throttler]
    )
    cand_extractor.apply(documents, split=0, parallelism=workers)

    assert count(PartTemp) == 1431
    assert count(Candidate) == 1431
    first_doc = documents[0]
    assert first_doc.name == "112823"
    assert len(first_doc.parts) == 70
    assert len(first_doc.temps) == 23

    # A bulk delete issued against the parent class must remove the child row
    victim = db.query(Candidate).first()
    matching = db.query(Candidate).filter_by(id=victim.id)
    matching.delete(synchronize_session="fetch")
    assert count(Candidate) == 1430
    assert count(PartTemp) == 1430

    # Deleting a candidate object must leave its mentions in place
    victim = db.query(PartTemp).first()
    to_remove = db.query(PartTemp).filter_by(id=victim.id).first()
    db.delete(to_remove)
    assert count(PartTemp) == 1429
    assert count(Temp) == 23
    assert count(Part) == 70

    # Clearing mentions must take every candidate with them
    extractor.clear()
    assert count(Mention) == 0
    assert count(Part) == 0
    assert count(Temp) == 0
    assert count(PartTemp) == 0
    assert count(Candidate) == 0
Ejemplo n.º 8
0
def test_cand_gen(caplog):
    """Test extracting candidates from mentions from documents.

    Also checks that the extractors reject argument lists of mismatched
    length, that throttlers may be omitted or partly ``None``, and that
    deleting candidates/mentions affects the other table as expected.
    """
    caplog.set_level(logging.INFO)
    # SpaCy on mac has issue on parallel parsing
    # NOTE(review): os.name is "posix" on Linux as well as macOS, so this
    # selects a single core on both; confirm whether that is intended.
    if os.name == "posix":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores

    def do_nothing_matcher(fig):
        """Accept every figure."""
        return True

    max_docs = 10
    # "postgresql://" is the dialect name SQLAlchemy documents; the legacy
    # "postgres://" alias is rejected by SQLAlchemy 1.4 and later.
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session,
                           structural=True,
                           lingual=True,
                           visual=True,
                           pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 5548
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    with pytest.raises(ValueError):
        mention_extractor = MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    with pytest.raises(ValueError):
        # The mention spaces are well-formed here so that the only problem
        # is the short matcher list.
        mention_extractor = MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 234
    assert session.query(Volt).count() == 107
    assert session.query(Temp).count() == 136
    assert session.query(Fig).count() == 223
    part = session.query(Part).order_by(Part.id).all()[0]
    volt = session.query(Volt).order_by(Volt.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.span}")
    logger.info(f"Volt: {volt.span}")
    logger.info(f"Temp: {temp.span}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    with pytest.raises(ValueError):
        candidate_extractor = CandidateExtractor(
            session,
            [PartTemp, PartVolt],
            throttlers=[
                temp_throttler,
                volt_throttler,
                volt_throttler,
            ],  # Fail, mismatched arity
        )

    with pytest.raises(ValueError):
        candidate_extractor = CandidateExtractor(
            session,
            [PartTemp],  # Fail, mismatched arity
            throttlers=[temp_throttler, volt_throttler],
        )

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt])  # Pass, no throttler

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 4141
    assert session.query(PartVolt).count() == 3610
    assert session.query(Candidate).count() == 7751
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0

    # Test with None in throttlers in candidate extractor
    candidate_extractor = CandidateExtractor(session, [PartTemp, PartVolt],
                                             throttlers=[temp_throttler, None])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    # PartTemp is throttled (fewer rows); PartVolt matches the unthrottled run.
    assert session.query(PartTemp).count() == 3879
    assert session.query(PartVolt).count() == 3610
    assert session.query(Candidate).count() == 7489
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0

    # Test with a throttler for every candidate class
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt],
        throttlers=[temp_throttler, volt_throttler])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 3879
    assert session.query(PartVolt).count() == 3266
    assert session.query(Candidate).count() == 7145
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].volts) == 33
    assert len(docs[0].temps) == 24

    # Test that deletion of a Candidate does not delete the Mention
    session.query(PartTemp).delete()
    assert session.query(PartTemp).count() == 0
    assert session.query(Temp).count() == 136
    assert session.query(Part).count() == 234

    # Test deletion of Candidate if Mention is deleted
    assert session.query(PartVolt).count() == 3266
    assert session.query(Volt).count() == 107
    session.query(Volt).delete()
    assert session.query(Volt).count() == 0
    assert session.query(PartVolt).count() == 0
Ejemplo n.º 9
0
def test_cand_gen(caplog):
    """Test extracting candidates from mentions from documents.

    Parsing is skipped when the database already holds ``max_docs``
    documents. The test also checks that ``MentionExtractor`` rejects
    argument lists of mismatched length and that deleting
    candidates/mentions affects the other table as expected.
    """
    caplog.set_level(logging.INFO)
    # SpaCy on mac has issue on parallel parsing
    # NOTE(review): os.name is "posix" on Linux as well as macOS, so this
    # selects a single core on both; confirm whether that is intended.
    if os.name == "posix":
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 10
    # "postgresql://" is the dialect name SQLAlchemy documents; the legacy
    # "postgres://" alias is rejected by SQLAlchemy 1.4 and later.
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing (only when the expected number of documents is not present)
    num_docs = session.query(Document).count()
    if num_docs != max_docs:
        logger.info("Parsing...")
        doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
        corpus_parser = Parser(structural=True,
                               lingual=True,
                               visual=True,
                               pdf_path=pdf_path)
        corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 5892
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    with pytest.raises(ValueError):
        mention_extractor = MentionExtractor(
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    with pytest.raises(ValueError):
        # The mention spaces are well-formed here so that the only problem
        # is the short matcher list.
        mention_extractor = MentionExtractor(
            [Part, Temp, Volt],
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor = MentionExtractor(
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 234
    assert session.query(Volt).count() == 108
    assert session.query(Temp).count() == 118
    part = session.query(Part).order_by(Part.id).all()[0]
    volt = session.query(Volt).order_by(Volt.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.span}")
    logger.info(f"Volt: {volt.span}")
    logger.info(f"Temp: {temp.span}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 3385
    assert session.query(PartVolt).count() == 3364
    assert session.query(Candidate).count() == 6749
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].volts) == 33
    assert len(docs[0].temps) == 18

    # Test that deletion of a Candidate does not delete the Mention
    session.query(PartTemp).delete()
    assert session.query(PartTemp).count() == 0
    assert session.query(Temp).count() == 118
    assert session.query(Part).count() == 234

    # Test deletion of Candidate if Mention is deleted
    assert session.query(PartVolt).count() == 3364
    assert session.query(Volt).count() == 108
    session.query(Volt).delete()
    assert session.query(Volt).count() == 0
    assert session.query(PartVolt).count() == 0
Ejemplo n.º 10
0
    station_throttler = lambda s: ('head' in get_ancestor_tag_names(s) and
                                   'title' in get_ancestor_tag_names(s))
    mentions = prune_duplicate_mentions(session, mentions, Station,
                                        station_throttler)

    # 5.) Candidate Extraction
    print("\n#5 Candidate extraction")
    StationPrice = candidate_classes[0]
    has_candidates = session.query(StationPrice).filter(
        StationPrice.split == 0).count() > 0

    candidate_extractor = CandidateExtractor(session, [StationPrice],
                                             throttlers=throttlers)
    for i, docs in enumerate([train_docs, dev_docs, test_docs]):
        if (not has_candidates):
            candidate_extractor.apply(docs, split=i, parallelism=PARALLEL)
        print(
            f"Number of Candidates in split={i}: {session.query(StationPrice).filter(StationPrice.split == i).count()}"
        )

    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)
    cands = [train_cands, dev_cands, test_cands]

    # 6.) Featurize candidates
    has_features = session.query(Feature).count() > 0
    print(f"\n#6 Candidate featurization ({not has_features})")

    featurizer = Featurizer(session, [StationPrice],
                            feature_extractors=FeatureExtractor(
Ejemplo n.º 11
0
def main(
    conn_string,
    stg_temp_min=False,
    stg_temp_max=False,
    polarity=False,
    ce_v_max=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    gpu=None,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Run the relation-extraction pipeline for the enabled relations.

    Stages, in order: parsing, mention extraction, candidate extraction,
    featurization, labeling, and per-relation training and scoring. CSV
    files with candidate probabilities are written for ``ce_v_max`` and
    ``polarity`` when those relations are enabled.

    Args:
        conn_string: Connection string handed to ``Meta.init``.
        stg_temp_min: Enable the "stg_temp_min" relation.
        stg_temp_max: Enable the "stg_temp_max" relation.
        polarity: Enable the "polarity" relation.
        ce_v_max: Enable the "ce_v_max" relation.
        max_docs: Forwarded to ``parse_dataset``.
        parse: Forwarded to ``parse_dataset`` as its ``first_time`` flag.
        first_time: When true, run mention extraction, candidate
            extraction, featurization and LF application, and pickle the
            feature matrices next to this file. When false, the feature
            matrices are loaded from those pickle files instead.
        re_label: When true (and ``first_time`` is false), update the
            labels of split 0 with the current labeling functions.
        gpu: Written to ``CUDA_VISIBLE_DEVICES`` when truthy and forwarded
            to ``discriminative_model``.
        parallel: Parallelism passed to every pipeline stage.
        log_dir: Log directory relative to this file; "logs" when falsy.
        verbose: Log at INFO level instead of WARNING.
    """
    # Setup initial configuration
    if gpu:
        # Environment values must be strings; str() also accepts an int id.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

    if not log_dir:
        log_dir = "logs"

    level = logging.INFO if verbose else logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    # Enabled relations. The candidate classes, LF lists and matrices below
    # are all built in this same order and indexed by position in rel_list.
    rel_list = []
    if stg_temp_min:
        rel_list.append("stg_temp_min")

    if stg_temp_max:
        rel_list.append("stg_temp_max")

    if polarity:
        rel_list.append("polarity")

    if ce_v_max:
        rel_list.append("ce_v_max")

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info("Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(session,
                                                          dirname,
                                                          first_time=parse,
                                                          parallel=parallel,
                                                          max_docs=max_docs)
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Part mentions are always extracted; the rest only when enabled
    Part = mention_subclass("Part")
    part_matcher = get_matcher("part")
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)

    mentions.append(Part)
    ngrams.append(part_ngrams)
    matchers.append(part_matcher)

    if stg_temp_min:
        StgTempMin = mention_subclass("StgTempMin")
        stg_temp_min_matcher = get_matcher("stg_temp_min")
        stg_temp_min_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMin)
        ngrams.append(stg_temp_min_ngrams)
        matchers.append(stg_temp_min_matcher)

    if stg_temp_max:
        StgTempMax = mention_subclass("StgTempMax")
        stg_temp_max_matcher = get_matcher("stg_temp_max")
        stg_temp_max_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMax)
        ngrams.append(stg_temp_max_ngrams)
        matchers.append(stg_temp_max_matcher)

    if polarity:
        Polarity = mention_subclass("Polarity")
        polarity_matcher = get_matcher("polarity")
        polarity_ngrams = MentionNgrams(n_max=1)

        mentions.append(Polarity)
        ngrams.append(polarity_ngrams)
        matchers.append(polarity_matcher)

    if ce_v_max:
        CeVMax = mention_subclass("CeVMax")
        ce_v_max_matcher = get_matcher("ce_v_max")
        ce_v_max_ngrams = MentionNgramsVolt(n_max=1)

        mentions.append(CeVMax)
        ngrams.append(ce_v_max_ngrams)
        matchers.append(ce_v_max_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")
    logger.info(f"Total Part: {session.query(Part).count()}")
    if stg_temp_min:
        logger.info(f"Total StgTempMin: {session.query(StgTempMin).count()}")
    if stg_temp_max:
        logger.info(f"Total StgTempMax: {session.query(StgTempMax).count()}")
    if polarity:
        logger.info(f"Total Polarity: {session.query(Polarity).count()}")
    if ce_v_max:
        logger.info(f"Total CeVMax: {session.query(CeVMax).count()}")

    # Candidate Extraction
    cands = []
    throttlers = []
    if stg_temp_min:
        PartStgTempMin = candidate_subclass("PartStgTempMin",
                                            [Part, StgTempMin])
        cands.append(PartStgTempMin)
        throttlers.append(stg_temp_filter)

    if stg_temp_max:
        PartStgTempMax = candidate_subclass("PartStgTempMax",
                                            [Part, StgTempMax])
        cands.append(PartStgTempMax)
        throttlers.append(stg_temp_filter)

    if polarity:
        PartPolarity = candidate_subclass("PartPolarity", [Part, Polarity])
        cands.append(PartPolarity)
        throttlers.append(polarity_filter)

    if ce_v_max:
        PartCeVMax = candidate_subclass("PartCeVMax", [Part, CeVMax])
        cands.append(PartCeVMax)
        throttlers.append(ce_v_max_filter)

    candidate_extractor = CandidateExtractor(session,
                                             cands,
                                             throttlers=throttlers)

    if first_time:
        # Split index: 0 = train, 1 = dev, 2 = test. The loop variable has
        # its own name so the parsed `docs` corpus is not rebound.
        for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(split_docs,
                                      split=i,
                                      parallelism=parallel)
            num_cands = session.query(Candidate).filter(
                Candidate.split == i).count()
            logger.info(f"Candidates in split={i}: {num_cands}")

    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)

    end = timer()
    logger.warning(
        f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"Total train candidate: {sum(len(_) for _ in train_cands)}")
    logger.info(f"Total dev candidate: {sum(len(_) for _ in dev_cands)}")
    logger.info(f"Total test candidate: {sum(len(_) for _ in test_cands)}")

    # NOTE: pickle executes arbitrary code on load; this file and the
    # F_*.pkl files below are expected to be trusted local artifacts.
    pickle_file = os.path.join(dirname, "data/parts_by_doc_new.pkl")
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    # Check total recall
    for i, name in enumerate(rel_list):
        logger.info(name)
        result = entity_level_scores(
            candidates_to_entities(dev_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=dev_docs,
        )
        logger.info(f"{name} Total Dev Recall: {result.rec:.3f}")
        result = entity_level_scores(
            candidates_to_entities(test_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=test_docs,
        )
        logger.info(f"{name} Total Test Recall: {result.rec:.3f}")

    # Featurization (reuses the candidate classes collected above)
    start = timer()
    feature_files = ("F_train.pkl", "F_dev.pkl", "F_test.pkl")

    featurizer = Featurizer(session, cands)
    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True, parallelism=parallel)
        featurizer.apply(split=1, parallelism=parallel)
        featurizer.apply(split=2, parallelism=parallel)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        # Cache the matrices so later runs can skip featurization; the
        # context manager closes each file even if pickling fails.
        for matrices, filename in zip((F_train, F_dev, F_test),
                                      feature_files):
            with open(os.path.join(dirname, filename), "wb") as f:
                pickle.dump(matrices, f)
    else:
        loaded = []
        for filename in feature_files:
            with open(os.path.join(dirname, filename), "rb") as f:
                loaded.append(pickle.load(f))
        F_train, F_dev, F_test = loaded
    logger.info("Done.")

    for i, cand in enumerate(cands):
        logger.info(f"{cand} Train shape: {F_train[i].shape}")
        logger.info(f"{cand} Test shape: {F_test[i].shape}")
        logger.info(f"{cand} Dev shape: {F_dev[i].shape}")

    logger.info("Labeling training data...")

    # Labeling
    start = timer()
    lfs = []
    if stg_temp_min:
        lfs.append(stg_temp_min_lfs)

    if stg_temp_max:
        lfs.append(stg_temp_max_lfs)

    if polarity:
        lfs.append(polarity_lfs)

    if ce_v_max:
        lfs.append(ce_v_max_lfs)

    labeler = Labeler(session, cands)

    # When debugging LFs, gold labels can be loaded with
    # load_transistor_labels(session, cands, ["ce_v_max"]) and the LFs can
    # additionally be applied to splits 1 and 2 with train=False.
    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
        logger.info("Done...")
    elif re_label:
        logger.info("Updating LFs...")
        labeler.update(split=0, lfs=lfs, parallelism=parallel)
        logger.info("Done...")

    logger.info("Getting label matrices...")

    L_train = labeler.get_label_matrices(train_cands)

    # When debugging LFs, the dev/test label matrices and gold labels can be
    # fetched here with labeler.get_label_matrices(...) and
    # labeler.get_gold_labels(..., annotator="gold").

    logger.info("Done.")

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    # Classification: train and score one model per enabled relation.
    start = timer()
    disc_models = {}
    for idx, relation in enumerate(rel_list):
        # When debugging LFs, analysis.lf_summary(L_train[idx],
        # lf_names=labeler.get_keys()) can be logged at this point.
        marginals = generative_model(L_train[idx])
        disc_models[relation] = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals,
            n_epochs=100,
            gpu=gpu,
        )

        # Scored on the test split. Pass dev_cands[idx], dev_docs and
        # F_dev[idx] instead to view the score on the development set.
        scoring(
            relation,
            disc_models[relation],
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

    end = timer()
    logger.warning(f"Classification Time (min): {((end - start) / 60.0):.1f}")

    # Dump CSV files for CE_V_MAX and POLARITY for digi-key analysis
    for relation in ("ce_v_max", "polarity"):
        if relation not in disc_models:
            continue
        idx = rel_list.index(relation)
        model = disc_models[relation]
        Y_prob = model.marginals((test_cands[idx], F_test[idx]))
        dump_candidates(test_cands[idx], Y_prob, f"{relation}_test_probs.csv")
        Y_prob = model.marginals((dev_cands[idx], F_dev[idx]))
        dump_candidates(dev_cands[idx], Y_prob, f"{relation}_dev_probs.csv")
Ejemplo n.º 12
0
def test_e2e_logistic_regression(caplog):
    """Run an end-to-end test on documents of the hardware domain.

    The test parses the HTML/PDF fixtures (skipped when the database already
    holds ``max_docs`` documents), extracts Part/Temp mentions and PartTemp
    candidates, featurizes and labels the candidates, then trains a
    generative model followed by a logistic-regression discriminative model.
    Entity-level F1 on the test split is checked twice: once with the first
    set of labeling functions and once after the second set has been added.
    """
    caplog.set_level(logging.INFO)
    # SpaCy on mac has issue on parallel parsing
    # NOTE(review): os.name is "posix" on Linux as well as macOS, so this
    # check also forces a single core on Linux CI -- confirm that is intended.
    if os.name == "posix":
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 12

    session = Meta.init("postgres://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    # Only parse when the database does not already hold the full corpus.
    if session.query(Document).count() != max_docs:
        logger.info("Parsing...")
        corpus_parser = Parser(structural=True,
                               lingual=True,
                               visual=True,
                               pdf_path=pdf_path)
        corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = session.query(Document).order_by(Document.name).all()
    ld = len(docs)

    # Per-document counts, in name order.
    assert [len(doc.sentences) for doc in docs] == [
        828, 706, 819, 684, 552, 758, 597, 165, 250, 533, 354, 547,
    ]

    # Check table numbers
    assert [len(doc.tables) for doc in docs] == [
        9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9,
    ]

    # Check figure numbers
    assert [len(doc.figures) for doc in docs] == [
        32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27,
    ]

    # Check caption numbers
    assert [len(doc.captions) for doc in docs] == [0] * max_docs

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    # Sort in Python as well so the split does not depend on DB collation.
    for i, doc in enumerate(sorted(docs, key=lambda d: d.name)):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor([Part, Temp],
                                         [part_ngrams, temp_ngrams],
                                         [part_matcher, temp_matcher])

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 127

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor([PartTemp],
                                             throttlers=[temp_throttler])

    # Use a dedicated loop name so the full ``docs`` list is not clobbered.
    for split, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs,
                                  split=split,
                                  parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3201
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 420

    train_cands = session.query(PartTemp).filter(PartTemp.split == 0).all()

    featurizer = FeatureAnnotator(PartTemp)
    F_train = featurizer.apply(split=0,
                               replace_key_set=True,
                               parallelism=PARALLEL)
    logger.info(F_train.shape)
    F_dev = featurizer.apply(split=1,
                             replace_key_set=False,
                             parallelism=PARALLEL)
    logger.info(F_dev.shape)
    F_test = featurizer.apply(split=2,
                              replace_key_set=False,
                              parallelism=PARALLEL)
    logger.info(F_test.shape)

    gold_file = "tests/data/hardware_tutorial_gold.csv"
    load_hardware_labels(session,
                         PartTemp,
                         gold_file,
                         ATTRIBUTE,
                         annotator_name="gold")

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = LabelAnnotator(PartTemp, lfs=stg_temp_lfs)
    L_train = labeler.apply(split=0, clear=True, parallelism=PARALLEL)
    logger.info(L_train.shape)

    load_gold_labels(session, annotator_name="gold", split=0)

    gen_model = GenerativeModel()
    gen_model.train(L_train,
                    epochs=500,
                    decay=0.9,
                    step_size=0.001 / L_train.shape[0],
                    reg_param=0)
    logger.info(f"LF Accuracy: {gen_model.weights.lf_accuracy}")

    load_gold_labels(session, annotator_name="gold", split=1)

    train_marginals = gen_model.marginals(L_train)

    disc_model = LogisticRegression()
    disc_model.train((train_cands, F_train),
                     train_marginals,
                     n_epochs=200,
                     lr=0.001)

    load_gold_labels(session, annotator_name="gold", split=2)

    test_candidates = [
        F_test.get_candidate(session, i) for i in range(F_test.shape[0])
    ]

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    # The pickle is a fixture shipped with the test data (trusted input).
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    def _positive_test_candidates(model):
        """Return the test candidates that ``model`` scores above zero."""
        test_score = model.predictions((test_candidates, F_test))
        return [
            test_candidates[idx]
            for idx in np.nditer(np.where(test_score > 0))
        ]

    def _entity_level_scores(true_pred):
        """Log and return entity-level (precision, recall, F1).

        Each metric is NaN when its denominator is zero.
        """
        (TP, FP, FN) = entity_level_f1(true_pred,
                                       gold_file,
                                       ATTRIBUTE,
                                       test_docs,
                                       parts_by_doc=parts_by_doc)
        tp_len = len(TP)
        fp_len = len(FP)
        fn_len = len(FN)
        nan = float("nan")
        prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else nan
        rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else nan
        f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else nan

        logger.info(f"prec: {prec}")
        logger.info(f"rec: {rec}")
        logger.info(f"f1: {f1}")
        return prec, rec, f1

    _, _, f1 = _entity_level_scores(_positive_test_candidates(disc_model))

    assert 0.3 < f1 < 0.7

    stg_temp_lfs_2 = [
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]

    # Add the second batch of labeling functions on top of the first and
    # retrain both models.
    labeler = LabelAnnotator(PartTemp, lfs=stg_temp_lfs_2)
    L_train = labeler.apply(split=0,
                            clear=False,
                            update_keys=True,
                            update_values=True,
                            parallelism=PARALLEL)
    gen_model = GenerativeModel()
    gen_model.train(L_train,
                    epochs=500,
                    decay=0.9,
                    step_size=0.001 / L_train.shape[0],
                    reg_param=0)
    train_marginals = gen_model.marginals(L_train)

    disc_model = LogisticRegression()
    disc_model.train((train_cands, F_train),
                     train_marginals,
                     n_epochs=200,
                     lr=0.001)

    _, _, f1 = _entity_level_scores(_positive_test_candidates(disc_model))

    assert f1 > 0.7
Ejemplo n.º 13
0
def test_e2e(caplog):
    """Run an end-to-end test on documents of the hardware domain.

    The test parses the HTML/PDF fixtures, extracts Part/Temp mentions and
    PartTemp candidates, featurizes and labels them, trains a label model,
    and then checks the entity-level F1 of several discriminative models
    (logistic regression, LSTM, sparse logistic regression) on the test
    split.
    """
    caplog.set_level(logging.INFO)
    # SpaCy on mac has issue on parallel parsing
    # NOTE(review): os.name is "posix" on Linux as well as macOS, so this
    # check also forces a single core on Linux CI -- confirm that is intended.
    if os.name == "posix":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 12

    session = Meta.init("postgres://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = session.query(Document).order_by(Document.name).all()
    ld = len(docs)

    # Per-document counts, in name order.
    assert [len(doc.sentences) for doc in docs] == [
        799, 663, 784, 661, 513, 700, 528, 161, 228, 511, 331, 528,
    ]

    # Check table numbers
    assert [len(doc.tables) for doc in docs] == [
        9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9,
    ]

    # Check figure numbers
    assert [len(doc.figures) for doc in docs] == [
        32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27,
    ]

    # Check caption numbers
    assert [len(doc.captions) for doc in docs] == [0] * max_docs

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    # Sort in Python as well so the split does not depend on DB collation.
    for i, doc in enumerate(sorted(docs, key=lambda d: d.name)):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(session, [Part, Temp],
                                         [part_ngrams, temp_ngrams],
                                         [part_matcher, temp_matcher])

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 134
    assert len(mention_extractor.get_mentions()) == 2
    assert len(mention_extractor.get_mentions()[0]) == 299
    assert (len(
        mention_extractor.get_mentions(docs=[
            session.query(Document).filter(Document.name == "112823").first()
        ])[0]) == 70)

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(session, [PartTemp],
                                             throttlers=[temp_throttler])

    # Use a dedicated loop name so the full ``docs`` list is not clobbered.
    for split, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs,
                                  split=split,
                                  parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3346
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 420

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 3346
    assert (len(
        candidate_extractor.get_candidates(docs=[
            session.query(Document).filter(Document.name == "112823").first()
        ])[0]) == 1178)

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 61
    assert session.query(FeatureKey).count() == 676

    # Test Dropping FeatureKey
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NFP NN NFP]"])
    assert session.query(FeatureKey).count() == 675
    session.query(Feature).delete()

    featurizer.apply(split=0, train=True, parallelism=1)
    assert session.query(Feature).count() == 3346
    assert session.query(FeatureKey).count() == 3578
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3346, 3578)
    assert len(featurizer.get_keys()) == 3578

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 3407
    assert session.query(FeatureKey).count() == 3578
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (61, 3578)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 3827
    assert session.query(FeatureKey).count() == 3578
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (420, 3578)

    gold_file = "tests/data/hardware_tutorial_gold.csv"
    load_hardware_labels(session,
                         PartTemp,
                         gold_file,
                         ATTRIBUTE,
                         annotator_name="gold")
    assert session.query(GoldLabel).count() == 3827

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    # A flat list of LFs is rejected; one list per candidate class is
    # expected (see the accepted call right below).
    with pytest.raises(ValueError):
        labeler.apply(split=0,
                      lfs=stg_temp_lfs,
                      train=True,
                      parallelism=PARALLEL)

    labeler.apply(split=0,
                  lfs=[stg_temp_lfs],
                  train=True,
                  parallelism=PARALLEL)
    assert session.query(Label).count() == 3346
    assert session.query(LabelKey).count() == 6
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3346, 6)
    assert len(labeler.get_keys()) == 6

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3346, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3346, 1)

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    # The pickle is a fixture shipped with the test data (trusted input).
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    def _train_and_score(model, marginals, n_epochs, b):
        """Train ``model`` on the train split and return its test-split F1.

        Precision, recall and F1 are logged; each is NaN when its
        denominator is zero.
        """
        model.train((train_cands[0], F_train[0]),
                    marginals,
                    n_epochs=n_epochs,
                    lr=0.001)

        test_score = model.predictions((test_cands[0], F_test[0]), b=b)
        true_pred = [
            test_cands[0][idx]
            for idx in np.nditer(np.where(test_score > 0))
        ]

        (TP, FP, FN) = entity_level_f1(true_pred,
                                       gold_file,
                                       ATTRIBUTE,
                                       test_docs,
                                       parts_by_doc=parts_by_doc)

        tp_len = len(TP)
        fp_len = len(FP)
        fn_len = len(FN)
        nan = float("nan")
        prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else nan
        rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else nan
        f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else nan

        logger.info(f"prec: {prec}")
        logger.info(f"rec: {rec}")
        logger.info(f"f1: {f1}")
        return f1

    gen_model = LabelLearner(cardinalities=2)
    gen_model.train(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])[:, 1]

    f1 = _train_and_score(LogisticRegression(),
                          train_marginals,
                          n_epochs=20,
                          b=0.6)

    assert 0.3 < f1 < 0.7

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0, lfs=[stg_temp_lfs_2], parallelism=PARALLEL)
    assert session.query(Label).count() == 3346
    assert session.query(LabelKey).count() == 13
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3346, 13)

    gen_model = LabelLearner(cardinalities=2)
    gen_model.train(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])[:, 1]

    f1 = _train_and_score(LogisticRegression(),
                          train_marginals,
                          n_epochs=20,
                          b=0.6)

    assert f1 > 0.7

    # Testing LSTM
    f1 = _train_and_score(LSTM(), train_marginals, n_epochs=5, b=0.6)

    assert f1 > 0.7

    # Testing Sparse Logistic Regression
    f1 = _train_and_score(SparseLogisticRegression(),
                          train_marginals,
                          n_epochs=20,
                          b=0.9)

    assert f1 > 0.7
Ejemplo n.º 14
0
def test_cand_gen_cascading_delete(caplog):
    """Test cascading the deletion of candidates.

    After parsing ten documents and extracting Part/Temp mentions and
    PartTemp candidates, the test checks that deleting a row through the
    parent ``Candidate`` class removes it from ``PartTemp`` too, and that
    clearing the mentions leaves no candidates behind.
    """
    caplog.set_level(logging.INFO)

    if platform == "darwin":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        logger.info("Using two cores.")
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 10
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session,
                           structural=True,
                           lingual=True,
                           visual=True,
                           pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 5548
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(session, [Part, Temp],
                                         [part_ngrams, temp_ngrams],
                                         [part_matcher, temp_matcher])
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Mention).count() == 370
    assert session.query(Part).count() == 234
    assert session.query(Temp).count() == 136
    # Only the lowest-id row is needed, so let the query fetch a single row
    # instead of loading every mention.
    part = session.query(Part).order_by(Part.id).first()
    temp = session.query(Temp).order_by(Temp.id).first()
    logger.info(f"Part: {part.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(session, [PartTemp],
                                             throttlers=[temp_throttler])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 3879
    assert session.query(Candidate).count() == 3879
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].temps) == 24

    # Delete from parent class should cascade to child
    x = session.query(Candidate).first()
    session.query(Candidate).filter_by(id=x.id).delete()
    assert session.query(PartTemp).count() == 3878
    assert session.query(Candidate).count() == 3878

    # Clearing Mentions should also delete Candidates
    mention_extractor.clear()
    assert session.query(Mention).count() == 0
    assert session.query(Part).count() == 0
    assert session.query(Temp).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(Candidate).count() == 0
Ejemplo n.º 15
0
def test_cand_gen():
    """Test extracting candidates from mentions from documents.

    Covers the arity validation of ``MentionExtractor`` and
    ``CandidateExtractor``, candidate extraction without throttlers, with a
    ``None`` throttler and with one throttler per relation, and the deletion
    behaviour between candidates and their mentions.
    """
    if platform == "darwin":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        # The worker count is hard-coded; it is not derived from the machine.
        logger.info("Using as many cores as available (PARALLEL=32)")
        PARALLEL = 32

    def do_nothing_matcher(fig):
        """Accept every figure."""
        return True

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session,
                           structural=True,
                           lingual=True,
                           visual=True,
                           pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    with pytest.raises(ValueError):
        MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    with pytest.raises(ValueError):
        # The mention spaces are well-formed here (the original passed
        # ``temp_matcher`` in this list by mistake), so the matcher count is
        # the only mismatch.
        MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 70
    assert session.query(Volt).count() == 33
    assert session.query(Temp).count() == 23
    assert session.query(Fig).count() == 31
    # Only the lowest-id row is needed, so let the query fetch a single row
    # instead of loading every mention.
    part = session.query(Part).order_by(Part.id).first()
    volt = session.query(Volt).order_by(Volt.id).first()
    temp = session.query(Temp).order_by(Temp.id).first()
    logger.info(f"Part: {part.context}")
    logger.info(f"Volt: {volt.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    with pytest.raises(ValueError):
        CandidateExtractor(
            session,
            [PartTemp, PartVolt],
            throttlers=[
                temp_throttler,
                volt_throttler,
                volt_throttler,
            ],  # Fail, mismatched arity
        )

    with pytest.raises(ValueError):
        CandidateExtractor(
            session,
            [PartTemp],  # Fail, mismatched arity
            throttlers=[temp_throttler, volt_throttler],
        )

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt])  # Pass, no throttler

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1610
    assert session.query(PartVolt).count() == 2310
    assert session.query(Candidate).count() == 3920
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(PartVolt).count() == 0

    # Test with None in throttlers in candidate extractor
    candidate_extractor = CandidateExtractor(session, [PartTemp, PartVolt],
                                             throttlers=[temp_throttler, None])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    assert session.query(PartTemp).count() == 1432
    assert session.query(PartVolt).count() == 2310
    assert session.query(Candidate).count() == 3742
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0

    # Test with one throttler per relation
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt],
        throttlers=[temp_throttler, volt_throttler])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1432
    assert session.query(PartVolt).count() == 1993
    assert session.query(Candidate).count() == 3425
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].volts) == 33
    assert len(docs[0].temps) == 23

    # Test that deletion of a Candidate does not delete the Mention
    session.query(PartTemp).delete(synchronize_session="fetch")
    assert session.query(PartTemp).count() == 0
    assert session.query(Temp).count() == 23
    assert session.query(Part).count() == 70

    # Test deletion of Candidate if Mention is deleted
    assert session.query(PartVolt).count() == 1993
    assert session.query(Volt).count() == 33
    session.query(Volt).delete(synchronize_session="fetch")
    assert session.query(Volt).count() == 0
    assert session.query(PartVolt).count() == 0
Ejemplo n.º 16
0
def test_incremental():
    """Run an end-to-end test on incremental additions.

    The test parses a single datasheet, then runs mention extraction,
    candidate extraction, featurization and labeling on it. A second
    datasheet is then parsed with ``clear=False`` and every stage is re-run
    on the new document only, asserting that the stored rows grow to cover
    both documents and that re-applying a stage to already-processed
    documents does not add duplicates.
    """
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 1

    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/dtc114w.html"
    pdf_path = "tests/data/pdf/dtc114w.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    docs = corpus_parser.get_documents()
    # Compare against the documents of the most recent apply(); fetching
    # get_documents() twice would make the assertion below trivially true.
    last_docs = corpus_parser.get_last_documents()

    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    mention_extractor = MentionExtractor(session, [Part, Temp],
                                         [part_ngrams, temp_ngrams],
                                         [part_matcher, temp_matcher])

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 11
    assert session.query(Temp).count() == 8

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])

    candidate_extractor = CandidateExtractor(session, [PartTemp],
                                             throttlers=[temp_throttler])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 70
    assert session.query(Candidate).count() == 70

    # Grab candidate lists (one inner list per candidate class)
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 70

    # Featurization
    featurizer = Featurizer(session, [PartTemp])

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 70
    assert session.query(FeatureKey).count() == 512

    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (70, 512)
    assert len(featurizer.get_keys()) == 512

    # Test Dropping FeatureKey
    # The key count is expected to stay at 512 after this call.
    # NOTE(review): presumably this key is not among the stored keys for
    # this document -- confirm.
    featurizer.drop_keys(["CORE_e1_LENGTH_1"])
    assert session.query(FeatureKey).count() == 512

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = Labeler(session, [PartTemp])

    labeler.apply(split=0,
                  lfs=[stg_temp_lfs],
                  train=True,
                  parallelism=PARALLEL)
    assert session.query(Label).count() == 70

    # Only 5 because LF_operating_row doesn't apply to the first test doc
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (70, 5)
    assert len(labeler.get_keys()) == 5

    # Incrementally add a second document without clearing the first one.
    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser.apply(doc_preprocessor, pdf_path=pdf_path, clear=False)

    assert len(corpus_parser.get_documents()) == 2

    new_docs = corpus_parser.get_last_documents()

    assert len(new_docs) == 1
    assert new_docs[0].name == "112823"

    # Get mentions from just the new docs
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)
    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Test if existing mentions are skipped.
    mention_extractor.apply(new_docs, parallelism=PARALLEL, clear=False)
    assert session.query(Part).count() == 81
    assert session.query(Temp).count() == 31

    # Just run candidate extraction and assign to split 0
    candidate_extractor.apply(new_docs,
                              split=0,
                              parallelism=PARALLEL,
                              clear=False)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1502

    # Test if existing candidates are skipped.
    candidate_extractor.apply(new_docs,
                              split=0,
                              parallelism=PARALLEL,
                              clear=False)
    train_cands = candidate_extractor.get_candidates(split=0)
    assert len(train_cands) == 1
    assert len(train_cands[0]) == 1502

    # Update features
    featurizer.update(new_docs, parallelism=PARALLEL)
    assert session.query(Feature).count() == 1502
    assert session.query(FeatureKey).count() == 2573
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (1502, 2573)
    assert len(featurizer.get_keys()) == 2573

    # Update LF_storage_row. Now it always returns ABSTAIN.
    # It is registered under the same name as the LF it replaces.
    @labeling_function(name="LF_storage_row")
    def LF_storage_row_updated(c):
        return ABSTAIN

    stg_temp_lfs = [
        LF_storage_row_updated,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    # Update Labels (both the original and the newly added documents)
    labeler.update(docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    labeler.update(new_docs, lfs=[stg_temp_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 1502
    # Only 5 because LF_storage_row doesn't apply to any doc (always ABSTAIN)
    assert session.query(LabelKey).count() == 5
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (1502, 5)

    # Test clear
    featurizer.clear(train=True)
    assert session.query(FeatureKey).count() == 0
Ejemplo n.º 17
0
def main(
    conn_string,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    gpu=None,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Run the thumbnail classification pipeline end to end.

    The pipeline parses the dataset, extracts figure mentions and
    single-mention candidates, labels dev/test candidates from the
    ground-truth file, trains an image classifier on the dev candidates and
    logs precision/recall/F1 on the test candidates.

    Args:
        conn_string: Database connection string handed to ``Meta.init``.
        max_docs: Forwarded to ``parse_dataset``.
        parse: Not referenced in this function body.
        first_time: Forwarded to ``parse_dataset`` and, when true, also
            triggers mention and candidate extraction; otherwise existing
            candidates are only queried.
        gpu: Not referenced in this function body.
        parallel: Parallelism used for parsing and for both extractors.
        log_dir: Log directory, relative to this file's directory. Falls
            back to ``"logs"`` when falsy.
        verbose: Log at INFO level when true, WARNING otherwise.
    """
    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info("Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=first_time, parallel=parallel, max_docs=max_docs
    )
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")

    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention and candidate extraction
    start = timer()

    Thumbnails = mention_subclass("Thumbnails")

    thumbnails_img = MentionFigures()

    class HasFigures(_Matcher):
        """Match figure mentions whose image file exists and is large enough."""

        def _f(self, m):
            """Return True if the figure's smaller dimension exceeds 50."""
            file_path = ""
            # No break on purpose: if the file exists under several prefixes,
            # the last matching prefix wins (same as the original behavior).
            for prefix in [
                f"{dirname}/data/train/html/",
                f"{dirname}/data/dev/html/",
                f"{dirname}/data/test/html/",
            ]:
                if os.path.exists(prefix + m.figure.url):
                    file_path = prefix + m.figure.url
            if file_path == "":
                return False
            # Close the image file deterministically; only its size is needed.
            with Image.open(file_path) as img:
                width, height = img.size
            min_value = min(width, height)
            return min_value > 50

    mention_extractor = MentionExtractor(
        session, [Thumbnails], [thumbnails_img], [HasFigures()], parallelism=parallel
    )

    if first_time:
        mention_extractor.apply(docs)

    logger.info("Total Mentions: {}".format(session.query(Mention).count()))

    ThumbnailLabel = candidate_subclass("ThumbnailLabel", [Thumbnails])

    candidate_extractor = CandidateExtractor(
        session, [ThumbnailLabel], throttlers=[None], parallelism=parallel
    )

    if first_time:
        candidate_extractor.apply(train_docs, split=0)
        candidate_extractor.apply(dev_docs, split=1)
        candidate_extractor.apply(test_docs, split=2)

    train_cands = candidate_extractor.get_candidates(split=0)
    # Sort the dev_cands, which are used for training, for deterministic behavior
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2)

    end = timer()
    logger.warning(f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info("Total train candidate:\t{}".format(len(train_cands[0])))
    logger.info("Total dev candidate:\t{}".format(len(dev_cands[0])))
    logger.info("Total test candidate:\t{}".format(len(test_cands[0])))

    # Each ground-truth line is lower-cased and its whitespace-separated
    # fields are joined with "::" to form a lookup key.
    with open(f"{dirname}/data/ground_truth.txt", "r") as fin:
        gt = {"::".join(line.lower().split()) for line in fin}

    # Labeling
    start = timer()

    def LF_gt_label(c):
        """Return TRUE if the candidate's "<doc>.pdf::<figure file>" key is in gt."""
        doc_file_id = (
            f"{c[0].context.figure.document.name.lower()}.pdf::"
            f"{os.path.basename(c[0].context.figure.url.lower())}"
        )
        return TRUE if doc_file_id in gt else FALSE

    gt_dev = [LF_gt_label(cand) for cand in dev_cands[0]]
    gt_test = [LF_gt_label(cand) for cand in test_cands[0]]

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    batch_size = 64
    input_size = 224
    K = 2

    emmental.init(log_dir=Meta.log_path, config=emmental_config)

    emmental.Meta.config["learner_config"]["task_scheduler_config"][
        "task_scheduler"
    ] = DauphinScheduler(augment_k=K, enlarge=1)

    # The dev candidates serve as both the training set (with augmentation)
    # and the validation set (without augmentation).
    train_dataset = ThumbnailDataset(
        "Thumbnail",
        dev_cands[0],
        gt_dev,
        "train",
        prob_label=True,
        prefix=f"{dirname}/data/dev/html/",
        input_size=input_size,
        transform_cls=Augmentation(2),
        k=K,
    )

    val_dataset = ThumbnailDataset(
        "Thumbnail",
        dev_cands[0],
        gt_dev,
        "valid",
        prob_label=False,
        prefix=f"{dirname}/data/dev/html/",
        input_size=input_size,
        k=1,
    )

    test_dataset = ThumbnailDataset(
        "Thumbnail",
        test_cands[0],
        gt_test,
        "test",
        prob_label=False,
        prefix=f"{dirname}/data/test/html/",
        input_size=input_size,
        k=1,
    )

    dataloaders = []

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=train_dataset,
            split="train",
            shuffle=True,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=val_dataset,
            split="valid",
            shuffle=False,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=test_dataset,
            split="test",
            shuffle=False,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    model = EmmentalModel(name="Thumbnail")
    model.add_task(
        create_task("Thumbnail", n_class=2, model="resnet18", pretrained=True)
    )

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, dataloaders)

    scores = model.score(dataloaders)

    logger.warning("Model Score:")
    logger.warning(f"precision: {scores['Thumbnail/Thumbnail/test/precision']:.3f}")
    logger.warning(f"recall: {scores['Thumbnail/Thumbnail/test/recall']:.3f}")
    logger.warning(f"f1: {scores['Thumbnail/Thumbnail/test/f1']:.3f}")
Ejemplo n.º 18
0
def _precision_recall_f1(tp_len, fp_len, fn_len):
    """Return ``(precision, recall, f1)`` for the given TP/FP/FN counts.

    A metric whose denominator is not positive is reported as NaN instead of
    raising ``ZeroDivisionError``. A NaN precision or recall also yields a
    NaN F1, because the ``prec + rec > 0`` comparison is false for NaN.
    """
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")
    return prec, rec, f1


def test_e2e():
    """Run an end-to-end test on documents of the hardware domain.

    Stages, in order: parse 12 datasheets and check their structure, extract
    mentions and candidates into train/dev/test splits, featurize, apply gold
    and heuristic labeling functions, fit a label model, then train and score
    a logistic-regression model (twice, with two LF sets) and an LSTM model
    on the first relation's candidates.
    """
    PARALLEL = 4

    max_docs = 12

    fonduer.init_logging(
        log_dir="log_folder",
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = fonduer.Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Check sentence numbers (one entry per document, in sorted order)
    assert [len(doc.sentences) for doc in docs] == [
        799, 663, 784, 661, 513, 700, 528, 161, 228, 511, 331, 528
    ]

    # Check table numbers
    assert [len(doc.tables) for doc in docs] == [
        9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9
    ]

    # Check figure numbers
    assert [len(doc.figures) for doc in docs] == [
        32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27
    ]

    # Check caption numbers
    assert [len(doc.captions) for doc in docs] == [0] * 12

    # First half of the name-sorted documents is train, the next quarter is
    # dev and the remainder is test.
    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (doc_name, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 138
    assert session.query(Volt).count() == 140
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299
    assert (len(
        mention_extractor.get_mentions(docs=[
            session.query(Document).filter(Document.name == "112823").first()
        ])[0]) == 70)

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt],
        throttlers=[temp_throttler, volt_throttler])

    # Use a dedicated loop variable so that `docs` keeps referring to the
    # full sorted document list after the loop.
    for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3493
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 416
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3493
    assert (len(
        candidate_extractor.get_candidates(docs=[
            session.query(Document).filter(Document.name == "112823").first()
        ])[0]) == 1432)

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    assert session.query(FeatureKey).count() == 1260

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NNP NN IN]"])
    assert session.query(FeatureKey).count() == 1259

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name ==
        "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes == ["part_temp"]
    assert session.query(FeatureKey).count() == 1259

    # Inserting the removed key
    featurizer.upsert_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                           candidate_classes=[PartTemp, PartVolt])
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    assert session.query(FeatureKey).count() == 1259
    # Removing the key again
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1258
    # Start from empty feature tables before the real featurization runs.
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6478
    assert session.query(FeatureKey).count() == 4538
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3493, 4538)
    assert F_train[1].shape == (2985, 4538)
    assert len(featurizer.get_keys()) == 4538

    # Without train=True the key count is expected to stay unchanged.
    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6692
    assert session.query(FeatureKey).count() == 4538
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (61, 4538)
    assert F_dev[1].shape == (153, 4538)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8252
    assert session.query(FeatureKey).count() == 4538
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (416, 4538)
    assert F_test[1].shape == (1144, 4538)

    gold_file = "tests/data/hardware_tutorial_gold.csv"

    labeler = Labeler(session, [PartTemp, PartVolt])

    # Gold labels for every candidate, stored in the GoldLabel table.
    labeler.apply(
        docs=last_docs,
        lfs=[[gold], [gold]],
        table=GoldLabel,
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(GoldLabel).count() == 8252

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]

    # A flat LF list (instead of one list per candidate class) must be
    # rejected with ValueError.
    with pytest.raises(ValueError):
        labeler.apply(split=0,
                      lfs=stg_temp_lfs,
                      train=True,
                      parallelism=PARALLEL)

    labeler.apply(
        docs=train_docs,
        lfs=[stg_temp_lfs, ce_v_max_lfs],
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    # Test Dropping LabelerKey
    labeler.drop_keys(["LF_storage_row"])
    assert len(labeler.get_keys()) == 8

    # Test Upserting LabelerKey
    labeler.upsert_keys(["LF_storage_row"])
    assert "LF_storage_row" in [label.name for label in labeler.get_keys()]

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3493, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3493, 1)

    gen_model = LabelModel()
    gen_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = gen_model.predict_proba(L_train[0])

    # Collect word counter
    word_counter = collect_word_counter(train_cands)

    emmental.init(fonduer.Meta.log_path)

    # Training config
    config = {
        "meta_config": {
            "verbose": False
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.0
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    f"{ATTRIBUTE}/{ATTRIBUTE}/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }
    emmental.Meta.update_config(config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)

    # Keep only candidates whose marginals are not (near-)uniform.
    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LogisticRegression")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(ATTRIBUTE, test_cands[0], F_test[0],
                               emb_layer.word2id, 2),
        split="test",
        batch_size=100,
        shuffle=False,
    )

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.6)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    # NOTE(review): pickle.load executes arbitrary code from the file; this
    # is only acceptable because the pickle is a fixture shipped with the
    # test data.
    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    prec, rec, f1 = _precision_recall_f1(len(TP), len(FP), len(FN))

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 < 0.7 and f1 > 0.3

    # Second round: replace the first relation's LFs with a larger set and
    # retrain; the F1 score is expected to exceed 0.7 this time.
    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0,
                   lfs=[stg_temp_lfs_2, ce_v_max_lfs],
                   parallelism=PARALLEL)
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 16)

    gen_model = LabelModel()
    gen_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = gen_model.predict_proba(L_train[0])

    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LogisticRegression")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    prec, rec, f1 = _precision_recall_f1(len(TP), len(FP), len(FN))

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7

    # Testing LSTM
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LSTM")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    prec, rec, f1 = _precision_recall_f1(len(TP), len(FP), len(FN))

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7
Ejemplo n.º 19
0
def main(
    conn_string,
    stg_temp_min=False,
    stg_temp_max=False,
    polarity=False,
    ce_v_max=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Run the transistor relation-extraction pipeline end to end.

    The stages, in order: parsing, mention extraction, candidate extraction,
    featurization, labeling, marginal estimation with ``generative_model``,
    multi-task training with Emmental, and scoring on the test split.

    Args:
        conn_string: Database connection string handed to ``Meta.init``.
        stg_temp_min: Enable the ``stg_temp_min`` relation.
        stg_temp_max: Enable the ``stg_temp_max`` relation.
        polarity: Enable the ``polarity`` relation.
        ce_v_max: Enable the ``ce_v_max`` relation.
        max_docs: Forwarded to ``parse_dataset``.
        parse: Forwarded to ``parse_dataset`` as its ``first_time`` flag.
        first_time: When true, run mention/candidate extraction,
            featurization and labeling, and write the feature matrices and
            marginals to pickle files next to this module. When false, those
            pickle files are read back instead.
        re_label: When ``first_time`` is false, update the labels of the
            training split with the current labeling functions.
        parallel: Parallelism used for parsing and extraction.
        log_dir: Log directory relative to this module; falls back to
            ``"logs"`` when empty or ``None``.
        verbose: Log at INFO level instead of WARNING.
    """
    if not log_dir:
        log_dir = "logs"

    level = logging.INFO if verbose else logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    def _dump_pickle(obj, filename):
        """Pickle ``obj`` to ``filename`` (relative to this module)."""
        # The context manager closes the handle even if pickling raises.
        with open(os.path.join(dirname, filename), "wb") as f:
            pickle.dump(obj, f)

    def _load_pickle(filename):
        """Unpickle ``filename`` (relative to this module)."""
        # NOTE(review): pickle must only be used on trusted, locally produced
        # files; loading executes arbitrary code embedded in the file.
        with open(os.path.join(dirname, filename), "rb") as f:
            return pickle.load(f)

    rel_list = []
    if stg_temp_min:
        rel_list.append("stg_temp_min")

    if stg_temp_max:
        rel_list.append("stg_temp_max")

    if polarity:
        rel_list.append("polarity")

    if ce_v_max:
        rel_list.append("ce_v_max")

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info("Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(session,
                                                          dirname,
                                                          first_time=parse,
                                                          parallel=parallel,
                                                          max_docs=max_docs)
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # The Part mention is always extracted; the others only when enabled.
    Part = mention_subclass("Part")
    part_matcher = get_matcher("part")
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)

    mentions.append(Part)
    ngrams.append(part_ngrams)
    matchers.append(part_matcher)

    if stg_temp_min:
        StgTempMin = mention_subclass("StgTempMin")
        stg_temp_min_matcher = get_matcher("stg_temp_min")
        stg_temp_min_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMin)
        ngrams.append(stg_temp_min_ngrams)
        matchers.append(stg_temp_min_matcher)

    if stg_temp_max:
        StgTempMax = mention_subclass("StgTempMax")
        stg_temp_max_matcher = get_matcher("stg_temp_max")
        stg_temp_max_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMax)
        ngrams.append(stg_temp_max_ngrams)
        matchers.append(stg_temp_max_matcher)

    if polarity:
        Polarity = mention_subclass("Polarity")
        polarity_matcher = get_matcher("polarity")
        polarity_ngrams = MentionNgrams(n_max=1)

        mentions.append(Polarity)
        ngrams.append(polarity_ngrams)
        matchers.append(polarity_matcher)

    if ce_v_max:
        CeVMax = mention_subclass("CeVMax")
        ce_v_max_matcher = get_matcher("ce_v_max")
        ce_v_max_ngrams = MentionNgramsVolt(n_max=1)

        mentions.append(CeVMax)
        ngrams.append(ce_v_max_ngrams)
        matchers.append(ce_v_max_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")
    logger.info(f"Total Part: {session.query(Part).count()}")
    if stg_temp_min:
        logger.info(f"Total StgTempMin: {session.query(StgTempMin).count()}")
    if stg_temp_max:
        logger.info(f"Total StgTempMax: {session.query(StgTempMax).count()}")
    if polarity:
        logger.info(f"Total Polarity: {session.query(Polarity).count()}")
    if ce_v_max:
        logger.info(f"Total CeVMax: {session.query(CeVMax).count()}")

    # Candidate Extraction
    # ``cands`` and ``throttlers`` are built in the same order as
    # ``rel_list`` so that index i refers to the same relation everywhere.
    cands = []
    throttlers = []
    if stg_temp_min:
        PartStgTempMin = candidate_subclass("PartStgTempMin",
                                            [Part, StgTempMin])
        cands.append(PartStgTempMin)
        throttlers.append(stg_temp_filter)

    if stg_temp_max:
        PartStgTempMax = candidate_subclass("PartStgTempMax",
                                            [Part, StgTempMax])
        cands.append(PartStgTempMax)
        throttlers.append(stg_temp_filter)

    if polarity:
        PartPolarity = candidate_subclass("PartPolarity", [Part, Polarity])
        cands.append(PartPolarity)
        throttlers.append(polarity_filter)

    if ce_v_max:
        PartCeVMax = candidate_subclass("PartCeVMax", [Part, CeVMax])
        cands.append(PartCeVMax)
        throttlers.append(ce_v_max_filter)

    candidate_extractor = CandidateExtractor(session,
                                             cands,
                                             throttlers=throttlers)

    if first_time:
        # A distinct loop name keeps the parsed ``docs`` list intact.
        for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(split_docs,
                                      split=i,
                                      parallelism=parallel)
            num_cands = session.query(Candidate).filter(
                Candidate.split == i).count()
            logger.info(f"Candidates in split={i}: {num_cands}")

    # These must be sorted for deterministic behavior.
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)

    end = timer()
    logger.warning(
        f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"Total train candidate: {sum(len(_) for _ in train_cands)}")
    logger.info(f"Total dev candidate: {sum(len(_) for _ in dev_cands)}")
    logger.info(f"Total test candidate: {sum(len(_) for _ in test_cands)}")

    parts_by_doc = _load_pickle("data/parts_by_doc_new.pkl")

    # Check total recall
    for i, name in enumerate(rel_list):
        logger.info(name)
        result = entity_level_scores(
            candidates_to_entities(dev_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=dev_docs,
        )
        logger.info(f"{name} Total Dev Recall: {result.rec:.3f}")
        result = entity_level_scores(
            candidates_to_entities(test_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=test_docs,
        )
        logger.info(f"{name} Total Test Recall: {result.rec:.3f}")

    # Featurization
    start = timer()

    # Using parallelism = 1 for deterministic behavior.
    featurizer = Featurizer(session, cands, parallelism=1)
    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True)
        featurizer.apply(split=1)
        featurizer.apply(split=2)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        # Cache the matrices keyed by relation name so that a later run with
        # a different subset of relations can still pick out its own.
        _dump_pickle(
            {rel: F_train[idx] for idx, rel in enumerate(rel_list)},
            "F_train_dict.pkl")
        _dump_pickle(
            {rel: F_dev[idx] for idx, rel in enumerate(rel_list)},
            "F_dev_dict.pkl")
        _dump_pickle(
            {rel: F_test[idx] for idx, rel in enumerate(rel_list)},
            "F_test_dict.pkl")
    else:
        F_train_dict = _load_pickle("F_train_dict.pkl")
        F_dev_dict = _load_pickle("F_dev_dict.pkl")
        F_test_dict = _load_pickle("F_test_dict.pkl")

        F_train = [F_train_dict[relation] for relation in rel_list]
        F_dev = [F_dev_dict[relation] for relation in rel_list]
        F_test = [F_test_dict[relation] for relation in rel_list]

    logger.info("Done.")

    for i, cand in enumerate(cands):
        logger.info(f"{cand} Train shape: {F_train[i].shape}")
        logger.info(f"{cand} Test shape: {F_test[i].shape}")
        logger.info(f"{cand} Dev shape: {F_dev[i].shape}")

    logger.info("Labeling training data...")

    # Labeling
    start = timer()
    lfs = []
    if stg_temp_min:
        lfs.append(stg_temp_min_lfs)

    if stg_temp_max:
        lfs.append(stg_temp_max_lfs)

    if polarity:
        lfs.append(polarity_lfs)

    if ce_v_max:
        lfs.append(ce_v_max_lfs)

    # Using parallelism = 1 for deterministic behavior.
    labeler = Labeler(session, cands, parallelism=1)

    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True)
        logger.info("Done...")

        # Uncomment if debugging LFs
        #  load_transistor_labels(session, cands, ["ce_v_max"])
        #  labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        #  labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    elif re_label:
        logger.info("Updating LFs...")
        labeler.update(split=0, lfs=lfs)
        logger.info("Done...")

        # Uncomment if debugging LFs
        #  labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        #  labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    logger.info("Getting label matrices...")

    L_train = labeler.get_label_matrices(train_cands)

    # Uncomment if debugging LFs
    #  L_dev = labeler.get_label_matrices(dev_cands)
    #  L_dev_gold = labeler.get_gold_labels(dev_cands, annotator="gold")
    #
    #  L_test = labeler.get_label_matrices(test_cands)
    #  L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")

    logger.info("Done.")

    if first_time:
        marginals_dict = {
            relation: generative_model(L_train[idx])
            for idx, relation in enumerate(rel_list)
        }
        _dump_pickle(marginals_dict, "marginals_dict.pkl")
    else:
        marginals_dict = _load_pickle("marginals_dict.pkl")

    marginals = [marginals_dict[relation] for relation in rel_list]

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    start = timer()

    word_counter = collect_word_counter(train_cands)

    # Training config
    config = {
        "meta_config": {
            "verbose": True,
            "seed": 17
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.0
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    "model/all/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }

    emmental.init(log_dir=Meta.log_path, config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)
    train_idxs = []
    train_dataloader = []
    for idx, relation in enumerate(rel_list):
        # Keep only candidates whose marginals are not (numerically) uniform
        # across classes.
        diffs = marginals[idx].max(axis=1) - marginals[idx].min(axis=1)
        train_idxs.append(np.where(diffs > 1e-6)[0])

        train_dataloader.append(
            EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(
                    relation,
                    train_cands[idx],
                    F_train[idx],
                    emb_layer.word2id,
                    marginals[idx],
                    train_idxs[idx],
                ),
                split="train",
                batch_size=100,
                shuffle=True,
            ))

    num_feature_keys = len(featurizer.get_keys())

    model = EmmentalModel(name="transistor_tasks")

    # List relation names, arities, list of classes
    tasks = create_task(
        rel_list,
        [2] * len(rel_list),
        num_feature_keys,
        [2] * len(rel_list),
        emb_layer,
        model="LogisticRegression",
    )

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()

    # Given a list of dataloaders, the learner trains on all of them.
    emmental_learner.learn(model, train_dataloader)

    # One test dataloader per relation
    for idx, relation in enumerate(rel_list):
        test_dataloader = EmmentalDataLoader(
            task_to_label_dict={relation: "labels"},
            dataset=FonduerDataset(relation, test_cands[idx], F_test[idx],
                                   emb_layer.word2id, 2),
            split="test",
            batch_size=100,
            shuffle=False,
        )

        test_preds = model.predict(test_dataloader, return_preds=True)

        scoring(
            relation,
            test_preds,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

        # Dump CSV files for CE_V_MAX and POLARITY for digi-key analysis
        if relation in ("ce_v_max", "polarity"):
            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=100,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            dump_candidates(test_cands[idx], Y_prob,
                            f"{relation}_test_probs.csv")
            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            dump_candidates(dev_cands[idx], Y_prob,
                            f"{relation}_dev_probs.csv")

    end = timer()
    logger.warning(f"Classification Time (min): {((end - start) / 60.0):.1f}")
Ejemplo n.º 20
0
def test_e2e(caplog):
    """Run an end-to-end test on documents of the hardware domain.

    Parses the HTML/PDF fixtures under ``tests/data``, extracts mentions and
    candidates, featurizes and labels them, and then checks the entity-level
    F1 of several discriminative models on the test split. It connects to a
    PostgreSQL server on ``localhost:5432`` using the database named by
    ``DB``.
    """
    caplog.set_level(logging.INFO)

    PARALLEL = 4

    max_docs = 12

    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Expected per-document counts, indexed by position in the sorted ``docs``.
    expected_counts = [
        ("sentences",
         [799, 663, 784, 661, 513, 700, 528, 161, 228, 511, 331, 528]),
        ("tables", [9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9]),
        ("figures", [32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27]),
        ("captions", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    ]
    for attr, expected in expected_counts:
        for idx, count in enumerate(expected):
            assert len(getattr(docs[idx], attr)) == count, f"docs[{idx}].{attr}"

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (doc_name, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 147
    assert session.query(Volt).count() == 140
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299
    assert (
        len(
            mention_extractor.get_mentions(
                docs=[session.query(Document).filter(Document.name == "112823").first()]
            )[0]
        )
        == 70
    )

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    # A distinct loop name keeps the parsed ``docs`` list intact.
    for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3684
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 72
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 448
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3684
    assert (
        len(
            candidate_extractor.get_candidates(
                docs=[session.query(Document).filter(Document.name == "112823").first()]
            )[0]
        )
        == 1496
    )

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 225
    assert session.query(FeatureKey).count() == 1179

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NFP NN NFP]"])
    assert session.query(FeatureKey).count() == 1178

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(
        session.query(FeatureKey)
        .filter(FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]")
        .one()
        .candidate_classes
    ) == {"part_temp", "part_volt"}
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name == "DDL_e1_LEMMA_SEQ_[bc182]"
    ).one().candidate_classes == ["part_temp"]
    assert session.query(FeatureKey).count() == 1178
    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"], candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1177
    session.query(Feature).delete()
    session.query(FeatureKey).delete()

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6669
    assert session.query(FeatureKey).count() == 4161
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3684, 4161)
    assert F_train[1].shape == (2985, 4161)
    assert len(featurizer.get_keys()) == 4161

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6894
    assert session.query(FeatureKey).count() == 4161
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (72, 4161)
    assert F_dev[1].shape == (153, 4161)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8486
    assert session.query(FeatureKey).count() == 4161
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (448, 4161)
    assert F_test[1].shape == (1144, 4161)

    gold_file = "tests/data/hardware_tutorial_gold.csv"
    load_hardware_labels(session, PartTemp, gold_file, ATTRIBUTE, annotator_name="gold")
    assert session.query(GoldLabel).count() == 4204
    load_hardware_labels(session, PartVolt, gold_file, ATTRIBUTE, annotator_name="gold")
    assert session.query(GoldLabel).count() == 8486

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]

    labeler = Labeler(session, [PartTemp, PartVolt])

    # A flat LF list does not match the two candidate classes and must be
    # rejected.
    with pytest.raises(ValueError):
        labeler.apply(split=0, lfs=stg_temp_lfs, train=True, parallelism=PARALLEL)

    labeler.apply(
        split=0, lfs=[stg_temp_lfs, ce_v_max_lfs], train=True, parallelism=PARALLEL
    )
    assert session.query(Label).count() == 6669
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3684, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3684, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3684, 1)

    gen_model = LabelModel(k=2)
    gen_model.train_model(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])[:, 1]

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    def _predict_true(disc_model_cls, marginals, n_epochs):
        """Train ``disc_model_cls`` on the first relation's training
        candidates and return the test candidates it predicts positive."""
        disc_model = disc_model_cls()
        disc_model.train(
            (train_cands[0], F_train[0]), marginals, n_epochs=n_epochs, lr=0.001
        )
        test_score = disc_model.predictions((test_cands[0], F_test[0]), b=0.6)
        # Index the tuple returned by np.where directly: np.nditer raises
        # ValueError on a zero-sized operand, i.e. when nothing is positive.
        return [test_cands[0][_] for _ in np.where(test_score > 0)[0]]

    def _entity_f1(true_pred):
        """Log entity-level precision/recall/F1 of ``true_pred`` against the
        gold file and return the F1 (``nan`` when it is undefined)."""
        (tp, fp, fn) = entity_level_f1(
            true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc
        )

        tp_len = len(tp)
        fp_len = len(fp)
        fn_len = len(fn)
        prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
        rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
        f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

        logger.info(f"prec: {prec}")
        logger.info(f"rec: {rec}")
        logger.info(f"f1: {f1}")
        return f1

    # With the first, weaker set of LFs the model is expected to be mediocre.
    f1 = _entity_f1(_predict_true(LogisticRegression, train_marginals, n_epochs=20))
    assert f1 < 0.7 and f1 > 0.3

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0, lfs=[stg_temp_lfs_2, ce_v_max_lfs], parallelism=PARALLEL)
    assert session.query(Label).count() == 6669
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3684, 16)

    gen_model = LabelModel(k=2)
    gen_model.train_model(L_train[0], n_epochs=500, print_every=100)

    train_marginals = gen_model.predict_proba(L_train[0])[:, 1]

    f1 = _entity_f1(_predict_true(LogisticRegression, train_marginals, n_epochs=20))
    assert f1 > 0.7

    # Testing LSTM
    f1 = _entity_f1(_predict_true(LSTM, train_marginals, n_epochs=5))
    assert f1 > 0.7

    # Testing Sparse Logistic Regression
    f1 = _entity_f1(
        _predict_true(SparseLogisticRegression, train_marginals, n_epochs=20)
    )
    assert f1 > 0.7

    # Testing Sparse LSTM
    f1 = _entity_f1(_predict_true(SparseLSTM, train_marginals, n_epochs=5))
    assert f1 > 0.7
Ejemplo n.º 21
0
def main(
    conn_string,
    gain=False,
    current=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    gpu=None,
    parallel=8,
    log_dir="logs",
    verbose=False,
):
    """Run the gain / supply-current extraction pipeline end to end.

    The stages are: parsing, mention extraction, candidate extraction,
    featurization, weak supervision (labeling functions), then, per enabled
    relation, training a discriminative model, scoring it on the test split
    and dumping CSV files.

    Args:
        conn_string: Database connection string handed to ``Meta.init``.
        gain: Enable the "gain" relation.
        current: Enable the "current" (supply current) relation.
        max_docs: Forwarded to ``parse_dataset`` as ``max_docs``.
        parse: Forwarded to ``parse_dataset`` as its ``first_time`` flag.
        first_time: When true, mention extraction, candidate extraction,
            featurization and labeling are (re)applied and the feature
            matrices are pickled next to this file. When false, the feature
            matrices are loaded from those pickle files instead.
        re_label: When ``first_time`` is false and this is true, the
            labeling functions are re-applied via ``labeler.update``.
        gpu: If truthy, written to ``CUDA_VISIBLE_DEVICES`` and forwarded to
            ``discriminative_model``.
        parallel: Parallelism passed to every ``apply``/``update`` call.
        log_dir: Log directory, relative to this file's directory. A falsy
            value falls back to ``"logs"``.
        verbose: Log at INFO level when true, WARNING otherwise.
    """
    # Setup initial configuration
    if gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    # rel_list is filled in the same order as the mention/candidate classes
    # below (gain first, then current), so rel_list.index(relation) is the
    # position of that relation in every per-relation list.
    rel_list = []
    if gain:
        rel_list.append("gain")

    if current:
        rel_list.append("current")

    logger.info("=" * 30)
    logger.info(f"Running with parallel: {parallel}, max_docs: {max_docs}")

    session = Meta.init(conn_string).Session()

    # Parsing
    start = timer()
    logger.info("Starting parsing...")
    docs, train_docs, dev_docs, test_docs = parse_dataset(session,
                                                          dirname,
                                                          first_time=parse,
                                                          parallel=parallel,
                                                          max_docs=max_docs)
    logger.debug("Done")
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of Documents: {len(docs)}")
    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    if gain:
        Gain = mention_subclass("Gain")
        gain_matcher = get_gain_matcher()
        gain_ngrams = MentionNgrams(n_max=2)
        mentions.append(Gain)
        ngrams.append(gain_ngrams)
        matchers.append(gain_matcher)

    if current:
        Current = mention_subclass("SupplyCurrent")
        current_matcher = get_supply_current_matcher()
        current_ngrams = MentionNgramsCurrent(n_max=3)
        mentions.append(Current)
        ngrams.append(current_ngrams)
        matchers.append(current_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")

    if gain:
        logger.info(f"Total Gain: {session.query(Gain).count()}")

    if current:
        logger.info(f"Total Current: {session.query(Current).count()}")

    cand_classes = []
    if gain:
        GainCand = candidate_subclass("GainCand", [Gain])
        cand_classes.append(GainCand)
    if current:
        CurrentCand = candidate_subclass("CurrentCand", [Current])
        cand_classes.append(CurrentCand)

    candidate_extractor = CandidateExtractor(session, cand_classes)

    if first_time:
        # Split ids: 0 = train, 1 = dev, 2 = test.
        for i, split_docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(split_docs,
                                      split=i,
                                      parallelism=parallel)

    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)
    # Sum over however many relations are enabled; indexing [0] and [1]
    # unconditionally fails when only one of gain/current is requested.
    logger.info(
        f"Total train candidate: {sum(len(cands) for cands in train_cands)}")
    logger.info(
        f"Total dev candidate: {sum(len(cands) for cands in dev_cands)}")
    logger.info(
        f"Total test candidate: {sum(len(cands) for cands in test_cands)}")

    logger.info("Done w/ candidate extraction.")
    end = timer()
    logger.warning(f"CE Time (min): {((end - start) / 60.0):.1f}")

    # First, check total recall
    #  result = entity_level_scores(dev_cands[0], corpus=dev_docs)
    #  logger.info(f"Gain Total Dev Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #  result = entity_level_scores(test_cands[0], corpus=test_docs)
    #  logger.info(f"Gain Total Test Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #
    #  result = entity_level_scores(dev_cands[1], corpus=dev_docs, is_gain=False)
    #  logger.info(f"Current Total Dev Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #  result = entity_level_scores(test_cands[1], corpus=test_docs, is_gain=False)
    #  logger.info(f"Current Test Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")

    start = timer()
    featurizer = Featurizer(session, cand_classes)

    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True, parallelism=parallel)
        featurizer.apply(split=1, parallelism=parallel)
        featurizer.apply(split=2, parallelism=parallel)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    # Serialize feature matrices on first run
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        # Context managers make sure each file is flushed and closed.
        for pkl_name, matrices in (
            ("F_train.pkl", F_train),
            ("F_dev.pkl", F_dev),
            ("F_test.pkl", F_test),
        ):
            with open(os.path.join(dirname, pkl_name), "wb") as pkl_file:
                pickle.dump(matrices, pkl_file)
    else:
        # NOTE: pickle executes arbitrary code on load; these files must only
        # ever be the ones written by a previous first_time run of this script.
        with open(os.path.join(dirname, "F_train.pkl"), "rb") as pkl_file:
            F_train = pickle.load(pkl_file)
        with open(os.path.join(dirname, "F_dev.pkl"), "rb") as pkl_file:
            F_dev = pickle.load(pkl_file)
        with open(os.path.join(dirname, "F_test.pkl"), "rb") as pkl_file:
            F_test = pickle.load(pkl_file)
    logger.info("Done.")

    start = timer()
    logger.info("Labeling training data...")
    labeler = Labeler(session, cand_classes)
    # One list of labeling functions per enabled relation, in rel_list order.
    lfs = []
    if gain:
        lfs.append(gain_lfs)

    if current:
        lfs.append(current_lfs)

    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
    elif re_label:
        logger.info("Re-applying LFs...")
        labeler.update(split=0, lfs=lfs, parallelism=parallel)

    logger.info("Done...")

    logger.info("Getting label matrices...")
    L_train = labeler.get_label_matrices(train_cands)
    logger.info("Done...")

    end = timer()
    logger.warning(
        f"Weak Supervision Time (min): {((end - start) / 60.0):.1f}")

    if gain:
        relation = "gain"
        idx = rel_list.index(relation)

        logger.info("Score Gain.")
        dev_gold_entities = get_gold_set(is_gain=True)
        # A dev candidate is labeled TRUE when any of its entities is in the
        # gold set, FALSE otherwise.
        L_dev_gt = []
        for c in dev_cands[idx]:
            flag = FALSE
            for entity in cand_to_entity(c, is_gain=True):
                if entity in dev_gold_entities:
                    flag = TRUE
            L_dev_gt.append(flag)

        marginals = generative_model(L_train[idx])
        disc_models = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals,
            X_dev=(dev_cands[idx], F_dev[idx]),
            Y_dev=L_dev_gt,
            n_epochs=500,
            gpu=gpu,
        )
        best_result, best_b = scoring(disc_models,
                                      test_cands[idx],
                                      test_docs,
                                      F_test[idx],
                                      num=50)

        print_scores(relation, best_result, best_b)

        logger.info("Output CSV files for Opo and Digi-key Analysis.")
        # The train split starts the CSV; test and dev are appended to it.
        Y_prob = disc_models.marginals((train_cands[idx], F_train[idx]))
        output_csv(train_cands[idx], Y_prob, is_gain=True)

        Y_prob = disc_models.marginals((test_cands[idx], F_test[idx]))
        output_csv(test_cands[idx], Y_prob, is_gain=True, append=True)
        dump_candidates(test_cands[idx],
                        Y_prob,
                        "gain_test_probs.csv",
                        is_gain=True)

        Y_prob = disc_models.marginals((dev_cands[idx], F_dev[idx]))
        output_csv(dev_cands[idx], Y_prob, is_gain=True, append=True)
        dump_candidates(dev_cands[idx],
                        Y_prob,
                        "gain_dev_probs.csv",
                        is_gain=True)

    if current:
        relation = "current"
        idx = rel_list.index(relation)

        logger.info("Score Current.")
        dev_gold_entities = get_gold_set(is_gain=False)
        # Same dev labeling rule as for gain, against the current gold set.
        L_dev_gt = []
        for c in dev_cands[idx]:
            flag = FALSE
            for entity in cand_to_entity(c, is_gain=False):
                if entity in dev_gold_entities:
                    flag = TRUE
            L_dev_gt.append(flag)

        marginals = generative_model(L_train[idx])

        disc_models = discriminative_model(
            train_cands[idx],
            F_train[idx],
            marginals,
            X_dev=(dev_cands[idx], F_dev[idx]),
            Y_dev=L_dev_gt,
            n_epochs=100,
            gpu=gpu,
        )
        best_result, best_b = scoring(disc_models,
                                      test_cands[idx],
                                      test_docs,
                                      F_test[idx],
                                      is_gain=False,
                                      num=50)

        print_scores(relation, best_result, best_b)

        logger.info("Output CSV files for Opo and Digi-key Analysis.")
        # Dump CSV files for digi-key analysis and Opo comparison
        Y_prob = disc_models.marginals((train_cands[idx], F_train[idx]))
        output_csv(train_cands[idx], Y_prob, is_gain=False)

        Y_prob = disc_models.marginals((test_cands[idx], F_test[idx]))
        output_csv(test_cands[idx], Y_prob, is_gain=False, append=True)
        dump_candidates(test_cands[idx],
                        Y_prob,
                        "current_test_probs.csv",
                        is_gain=False)

        Y_prob = disc_models.marginals((dev_cands[idx], F_dev[idx]))
        output_csv(dev_cands[idx], Y_prob, is_gain=False, append=True)
        dump_candidates(dev_cands[idx],
                        Y_prob,
                        "current_dev_probs.csv",
                        is_gain=False)

    end = timer()
    logger.warning(
        f"Classification AND dump data Time (min): {((end - start) / 60.0):.1f}"
    )
Ejemplo n.º 22
0
def test_too_many_clients_error_should_not_happen():
    """Check that a "too many clients" error does not happen.

    Parsing, mention extraction and candidate extraction are each run with
    the chosen parallelism (1 on macOS, 32 elsewhere); the test passes if
    none of them raises.
    """
    if platform == "darwin":
        PARALLEL = 1
    else:
        PARALLEL = 32
    # f-string so the actual parallelism is logged, not the literal placeholder.
    logger.info(f"Parallel: {PARALLEL}")

    def do_nothing_matcher(fig):
        """Accept every figure."""
        return True

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session,
                           structural=True,
                           lingual=True,
                           visual=True,
                           pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    # Candidate extractor without any throttler
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt])  # Pass, no throttler

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    candidate_extractor.clear_all(split=0)

    # Candidate extractor with None as one of the throttlers
    candidate_extractor = CandidateExtractor(session, [PartTemp, PartVolt],
                                             throttlers=[temp_throttler, None])

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
Ejemplo n.º 23
0
def test_e2e(database_session):
    """Run an end-to-end test on documents of the hardware domain.

    Covers, in order: parsing, mention extraction, candidate extraction,
    featurization (including dropping/upserting feature keys), gold and
    weak-supervision labeling, label-model fitting, and training Emmental
    LogisticRegression / LSTM tasks whose entity-level F1 on the test split
    is asserted against fixed thresholds.

    The statements depend on each other through database state (row counts
    are asserted after each step), so their order matters.
    """
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 12

    fonduer.init_logging(
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = database_session

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual_parser=PdfVisualParser(pdf_path),
    )
    corpus_parser.apply(doc_preprocessor)
    assert session.query(Document).count() == max_docs

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    # Documents are ordered by name; positions below splits[0] * ld go to
    # train, positions below splits[1] * ld go to dev, the rest go to test.
    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (doc_name, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    # One list of mentions per mention class (Part, Temp, Volt).
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions(docs)) == 3

    # Candidate Extraction
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt],
        throttlers=[temp_throttler, volt_throttler])

    # Split ids: 0 = train, 1 = dev, 2 = test.
    # NOTE: this loop rebinds `docs`; after it, `docs` is `test_docs`.
    for i, docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(docs, split=i, parallelism=PARALLEL)

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)

    # Candidate lists should be deterministically sorted.
    assert ("112823::implicit_span_mention:11059:11065:part_expander:0" ==
            train_cands[0][0][0].context.get_stable_id())
    assert ("112823::implicit_span_mention:2752:2754:temp_expander:0" ==
            train_cands[0][0][1].context.get_stable_id())

    # One candidate list per candidate class (PartTemp, PartVolt).
    assert len(train_cands) == 2
    assert len(candidate_extractor.get_candidates(docs)) == 2

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    num_feature_keys = session.query(FeatureKey).count()
    assert num_feature_keys == 1278

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["BASIC_e0_CONTAINS_WORDS_[BC182]"])
    assert session.query(FeatureKey).count() == num_feature_keys - 1

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e0_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    featurizer.drop_keys(["DDL_e0_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name ==
        "DDL_e0_LEMMA_SEQ_[bc182]").one().candidate_classes == ["part_temp"]
    assert session.query(FeatureKey).count() == num_feature_keys - 1

    # Inserting the removed key
    featurizer.upsert_keys(["DDL_e0_LEMMA_SEQ_[bc182]"],
                           candidate_classes=[PartTemp, PartVolt])
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e0_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    assert session.query(FeatureKey).count() == num_feature_keys - 1
    # Removing the key again
    featurizer.drop_keys(["DDL_e0_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e0_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == num_feature_keys - 2
    # Wipe all features and keys so the featurization below starts clean.
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    # The number of Features should equal the total number of train candidates
    num_features = session.query(Feature).count()
    assert num_features == len(train_cands[0]) + len(train_cands[1])
    num_feature_keys = session.query(FeatureKey).count()
    assert num_feature_keys == 4629
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (len(train_cands[0]), num_feature_keys)
    assert F_train[1].shape == (len(train_cands[1]), num_feature_keys)
    assert len(featurizer.get_keys()) == num_feature_keys

    featurizer.apply(split=1, parallelism=PARALLEL)
    # The number of Features should increase by the total number of dev candidates
    num_features += len(dev_cands[0]) + len(dev_cands[1])
    assert session.query(Feature).count() == num_features

    # Without train=True the set of feature keys must stay unchanged.
    assert session.query(FeatureKey).count() == num_feature_keys
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (len(dev_cands[0]), num_feature_keys)
    assert F_dev[1].shape == (len(dev_cands[1]), num_feature_keys)

    featurizer.apply(split=2, parallelism=PARALLEL)
    # The number of Features should increase by the total number of test candidates
    num_features += len(test_cands[0]) + len(test_cands[1])
    assert session.query(Feature).count() == num_features
    assert session.query(FeatureKey).count() == num_feature_keys
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (len(test_cands[0]), num_feature_keys)
    assert F_test[1].shape == (len(test_cands[1]), num_feature_keys)

    gold_file = "tests/data/hardware_tutorial_gold.csv"

    labeler = Labeler(session, [PartTemp, PartVolt])

    # This should raise an error, since gold labels are not yet loaded.
    with pytest.raises(ValueError):
        _ = labeler.get_gold_labels(train_cands, annotator="gold")

    labeler.apply(
        docs=last_docs,
        lfs=[[gold], [gold]],
        table=GoldLabel,
        train=True,
        parallelism=PARALLEL,
    )
    # All candidates should now be gold-labeled.
    assert session.query(GoldLabel).count() == session.query(Candidate).count()

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]

    # A flat list of LFs (instead of one list per candidate class) is
    # expected to be rejected with ValueError.
    with pytest.raises(ValueError):
        labeler.apply(split=0,
                      lfs=stg_temp_lfs,
                      train=True,
                      parallelism=PARALLEL)

    labeler.apply(
        docs=train_docs,
        lfs=[stg_temp_lfs, ce_v_max_lfs],
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(Label).count() == len(train_cands[0]) + len(
        train_cands[1])
    num_label_keys = session.query(LabelKey).count()
    # 6 LFs in stg_temp_lfs + 3 LFs in ce_v_max_lfs
    assert num_label_keys == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (len(train_cands[0]), num_label_keys)
    assert L_train[1].shape == (len(train_cands[1]), num_label_keys)
    assert len(labeler.get_keys()) == num_label_keys

    # Test Dropping LabelerKey
    labeler.drop_keys(["LF_storage_row"])
    assert len(labeler.get_keys()) == num_label_keys - 1

    # Test Upserting LabelerKey
    labeler.upsert_keys(["LF_storage_row"])
    assert "LF_storage_row" in [label.name for label in labeler.get_keys()]

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (len(train_cands[0]), 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (len(train_cands[0]), 1)

    # Fit the label model on the first relation's label matrix only.
    label_model = LabelModel(cardinality=2)
    label_model.fit(L_train=L_train[0], n_epochs=500, seed=1234, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    # Collect word counter
    word_counter = collect_word_counter(train_cands)

    emmental.init(fonduer.Meta.log_path)

    # Training config
    config = {
        "meta_config": {
            "verbose": False
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.0
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    f"{ATTRIBUTE}/{ATTRIBUTE}/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }
    emmental.Meta.update_config(config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens: an opening and a closing marker per argument
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)

    # Keep only the rows whose marginals are not (numerically) uniform,
    # i.e. max and min class probability differ by more than 1e-6.
    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LogisticRegression")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    # NOTE(review): the trailing `2` is passed where the train loader passes
    # its labels; presumably a placeholder since test labels are not used
    # for prediction. Confirm against FonduerDataset.
    test_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(ATTRIBUTE, test_cands[0], F_test[0],
                               emb_layer.word2id, 2),
        split="test",
        batch_size=100,
        shuffle=False,
    )

    # Candidates whose probability at class index TRUE exceeds 0.6 are
    # treated as positive predictions.
    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.6)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    # Precision / recall / F1 from the entity sets; NaN when a denominator
    # would be zero (a NaN F1 fails the comparison asserts below).
    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    # With the first, smaller LF set the score is expected to be mediocre.
    assert f1 < 0.7 and f1 > 0.3

    # Second, larger set of temperature LFs; re-labeling with it should push
    # F1 above 0.7 in the model tests below.
    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0,
                   lfs=[stg_temp_lfs_2, ce_v_max_lfs],
                   parallelism=PARALLEL)
    assert session.query(Label).count() == len(train_cands[0]) + len(
        train_cands[1])
    num_label_keys = session.query(LabelKey).count()
    assert num_label_keys == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (len(train_cands[0]), num_label_keys)

    label_model = LabelModel(cardinality=2)
    label_model.fit(L_train=L_train[0], n_epochs=500, seed=1234, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    # Validation loader reuses the training candidates, with hard labels
    # taken as the argmax of the marginals.
    valid_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            np.argmax(train_marginals, axis=1),
            train_idxs,
        ),
        split="valid",
        batch_size=100,
        shuffle=False,
    )

    # Testing STL LogisticRegression
    # Emmental's Meta is reset and re-initialized before each model below.
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(
        ATTRIBUTE,
        2,
        F_train[0].shape[1],
        2,
        emb_layer,
        model="LogisticRegression",
        mode="STL",
    )

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader, valid_dataloader])

    # From here on the positive-prediction threshold is 0.7 instead of 0.6.
    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7

    # Testing STL LSTM
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LSTM",
                        mode="STL")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7

    # Testing MTL LogisticRegression
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(
        ATTRIBUTE,
        2,
        F_train[0].shape[1],
        2,
        emb_layer,
        model="LogisticRegression",
        mode="MTL",
    )

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader, valid_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7

    # Testing MTL LSTM
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LSTM",
                        mode="MTL")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7
Ejemplo n.º 24
0
#sents = session.query(Sentence).all()

from fonduer.candidates import CandidateExtractor, MentionExtractor, MentionNgrams
from fonduer.candidates.models import mention_subclass, candidate_subclass
from fonduer.candidates.matchers import RegexMatchSpan, Union, LambdaFunctionMatcher
from dataset_utils import price_match

# Relation name and the n-gram space (n_max=5) searched for mentions.
extraction_name = "price"
ngrams = MentionNgrams(n_max=5)

# A single lambda matcher: price_match decides which spans are kept.
matchers = LambdaFunctionMatcher(func=price_match)

# Mention extraction: one mention class, one n-gram space, one matcher.
# Existing mentions are cleared before the extractor is applied to docs.
PriceMention = mention_subclass("PriceMention")
mention_extractor = MentionExtractor(
    session,
    [PriceMention],
    [ngrams],
    [matchers],
)
mention_extractor.clear_all()
mention_extractor.apply(docs, parallelism=parallelism)

# Candidate extraction: unary "Price" candidates written to split 0.
candidate_class = candidate_subclass("Price", [PriceMention])
candidate_extractor = CandidateExtractor(session, [candidate_class])
candidate_extractor.apply(docs, split=0, parallelism=parallelism)

# Report how many candidates ended up in split 0.
separator_line = "=============================="
print(separator_line)
print(f"Candidate extraction results for {postgres_db_name}:")
n_split0_candidates = (
    session.query(candidate_class)
    .filter(candidate_class.split == 0)
    .count()
)
print("Number of candidates:", n_split0_candidates)
print(separator_line)
Ejemplo n.º 25
0
def test_binary_relation_feature_extraction():
    """Check feature-key counts produced by different feature extractor setups.

    Parses a single document, extracts Part and Temp mentions and PartTemp
    candidates, then featurizes the candidates several times: with the default
    feature library, with the default library plus a custom function, with a
    custom function that raises, and with each of the textual, tabular,
    structural and visual libraries alone. Features are cleared after every
    run. The default library must produce as many feature keys as the four
    individual libraries combined, and the custom function must contribute
    exactly one extra key per candidate.
    """
    PARALLEL = 1
    max_docs = 1
    html_dir = "tests/data/html/"
    pdf_dir = "tests/data/pdf/"

    session = Meta.init(CONN_STRING).Session()

    def count_feature_keys(active_featurizer):
        """Apply the featurizer on split 0, count FeatureKey rows, clear."""
        active_featurizer.apply(split=0, train=True, parallelism=PARALLEL)
        n_keys = session.query(FeatureKey).count()
        active_featurizer.clear(train=True)
        return n_keys

    # --- Parsing ---
    logger.info("Parsing...")
    preprocessor = HTMLDocPreprocessor(html_dir, max_docs=max_docs)
    parser = Parser(
        session,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_dir,
    )
    parser.apply(preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # --- Mention extraction ---
    ngram_spaces = [MentionNgrams(n_max=1), MentionNgrams(n_max=1)]
    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")

    MentionExtractor(
        session,
        [Part, Temp],
        ngram_spaces,
        [part_matcher, temp_matcher],
    ).apply(docs, parallelism=PARALLEL)

    assert docs[0].name == "112823"
    assert session.query(Part).count() == 58
    assert session.query(Temp).count() == 16
    first_part = session.query(Part).order_by(Part.id).all()[0]
    first_temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {first_part.context}")
    logger.info(f"Temp: {first_temp.context}")

    # --- Candidate extraction ---
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    CandidateExtractor(session, [PartTemp]).apply(
        docs, split=0, parallelism=PARALLEL
    )
    n_cands = session.query(PartTemp).count()

    # --- Default feature library ---
    n_default_feats = count_feature_keys(Featurizer(session, [PartTemp]))

    # --- Default library plus one custom feature function ---
    def feat_ext(candidates):
        # Accept either a single candidate or a list of candidates.
        if not isinstance(candidates, list):
            candidates = [candidates]
        for cand in candidates:
            yield cand.id, f"cand_id_{cand.id}", 1

    n_default_w_customized_features = count_feature_keys(
        Featurizer(
            session,
            [PartTemp],
            feature_extractors=FeatureExtractor(
                customize_feature_funcs=[feat_ext]
            ),
        )
    )

    # --- A custom feature function that always raises ---
    def bad_feat_ext(candidates):
        raise RuntimeError()

    spurious_featurizer = Featurizer(
        session,
        [PartTemp],
        feature_extractors=FeatureExtractor(
            customize_feature_funcs=[bad_feat_ext]
        ),
    )
    logger.info("Featurizing with a spurious feature extractor...")
    # No feature-key count is taken for this run; it is only applied/cleared.
    spurious_featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    spurious_featurizer.clear(train=True)

    # --- Each feature library on its own, in the original order ---
    per_library_counts = [
        count_feature_keys(
            Featurizer(
                session,
                [PartTemp],
                feature_extractors=FeatureExtractor(features=[library]),
            )
        )
        for library in ("textual", "tabular", "structural", "visual")
    ]

    # The default library is the union of the four individual libraries.
    assert n_default_feats == sum(per_library_counts)

    # The custom function emits exactly one distinct key per candidate.
    assert n_default_w_customized_features == n_default_feats + n_cands