示例#1
0
import os

# Candidate type relating a 'product' mention to a 'figure'.
Org_Fig = candidate_subclass('Org_Fig', ['product', 'figure'])

from fonduer import HTMLPreprocessor, OmniParser

# os.path.join only inserts a separator when FONDUERHOME lacks a trailing
# one, so these paths are correct whether or not the variable ends in '/'.
docs_path = os.path.join(os.environ['FONDUERHOME'],
                         'tutorials/organic_synthesis_figures/data/html/')
pdf_path = os.path.join(os.environ['FONDUERHOME'],
                        'tutorials/organic_synthesis_figures/data/pdf/')

# Load the gold labels for the Org_Fig candidates.
from tutorials.organic_synthesis_figures.organic_utils import load_organic_labels

gold_file = os.path.join(
    os.environ['FONDUERHOME'],
    'tutorials/organic_synthesis_figures/data/organic_gold.csv')
load_organic_labels(session, Org_Fig, gold_file, ATTRIBUTE, annotator_name='gold')

from tutorials.organic_synthesis_figures import organic_lfs
from fonduer import BatchLabelAnnotator

# Apply the organic labeling functions to the training split (split 0).
labeler = BatchLabelAnnotator(Org_Fig, lfs=organic_lfs.org_fig_lfs)
L_train = labeler.apply(split=0, clear=True, parallelism=PARALLEL)

# Compare the labeling functions against the gold labels of split 0.
from fonduer import load_gold_labels
L_gold_train = load_gold_labels(session, annotator_name='gold', split=0)
# A bare expression is only displayed inside a notebook; print the stats so
# they are also visible when this file runs as a script.
print(L_train.lf_stats(L_gold_train))
def test_e2e(caplog):
    """Run an end-to-end test on 12 documents of the hardware domain.

    Parses the HTML/PDF fixtures under ``tests/e2e/data``, extracts
    (part, attr) candidates, featurizes them, and trains a generative model
    plus a sparse logistic regression twice: first with positive-only
    labeling functions (test-split entity-level F1 must lie in (0.4, 0.7)),
    then after adding negative labeling functions (F1 must exceed 0.7).

    Args:
        caplog: pytest's log-capture fixture; its level is set to INFO.
    """
    caplog.set_level(logging.INFO)
    PARALLEL = 2
    max_docs = 12

    session = SnorkelSession()

    Part_Attr = candidate_subclass('Part_Attr', ['part', 'attr'])

    docs_path = 'tests/e2e/data/html/'
    pdf_path = 'tests/e2e/data/pdf/'

    doc_preprocessor = HTMLPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = OmniParser(
        structural=True, lingual=True, visual=True, pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    num_docs = session.query(Document).count()
    logger.info("Docs: {}".format(num_docs))
    assert num_docs == max_docs

    num_phrases = session.query(Phrase).count()
    logger.info("Phrases: {}".format(num_phrases))

    # Divide into train / dev / test: the first 50% of documents (ordered by
    # name) go to train, the next 25% to dev and the remainder to test.
    docs = session.query(Document).order_by(Document.name).all()
    ld = len(docs)

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (_, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # Attribute mentions: 150, 155, ..., 195, 200 or 205.
    attr_matcher = RegexMatchSpan(
        rgx=r'(?:[1][5-9]|20)[05]', longest_match_only=False)

    ### Transistor Naming Conventions as Regular Expressions ###
    eeca_rgx = r'([ABC][A-Z][WXYZ]?[0-9]{3,5}(?:[A-Z]){0,5}[0-9]?[A-Z]?(?:-[A-Z0-9]{1,7})?(?:[-][A-Z0-9]{1,2})?(?:\/DG)?)'
    jedec_rgx = r'(2N\d{3,4}[A-Z]{0,5}[0-9]?[A-Z]?)'
    jis_rgx = r'(2S[ABCDEFGHJKMQRSTVZ]{1}[\d]{2,4})'
    others_rgx = r'((?:NSVBC|SMBT|MJ|MJE|MPS|MRF|RCA|TIP|ZTX|ZT|ZXT|TIS|TIPL|DTC|MMBT|SMMBT|PZT|FZT|STD|BUV|PBSS|KSC|CXT|FCX|CMPT){1}[\d]{2,4}[A-Z]{0,5}(?:-[A-Z0-9]{0,6})?(?:[-][A-Z0-9]{0,1})?)'

    part_rgx = '|'.join([eeca_rgx, jedec_rgx, jis_rgx, others_rgx])
    part_rgx_matcher = RegexMatchSpan(rgx=part_rgx, longest_match_only=True)

    def get_digikey_parts_set(path):
        """Read the Digi-Key part dictionary CSV and return its part numbers.

        Every row must have exactly two columns (part, url); only the part
        column is kept.
        """
        all_parts = set()
        # The csv module documentation asks for newline='' on reader files.
        with open(path, "r", newline="") as csvinput:
            for part, _url in csv.reader(csvinput):
                all_parts.add(part)
        return all_parts

    ### Dictionary of known transistor parts ###
    dict_path = 'tests/e2e/data/digikey_part_dictionary.csv'
    part_dict_matcher = DictionaryMatch(d=get_digikey_parts_set(dict_path))

    def common_prefix_length_diff(str1, str2):
        """Return how many characters of the shorter string lie at or after
        the first mismatch (0 when one string is a prefix of the other)."""
        for i in range(min(len(str1), len(str2))):
            if str1[i] != str2[i]:
                return min(len(str1), len(str2)) - i
        return 0

    def part_file_name_conditions(attr):
        """Accept spans that resemble the part number in the file name.

        The document name must look like ``<prefix>_<part>``; the span must
        not start with '-', must contain a digit and a letter (ignoring '-'),
        and ``common_prefix_length_diff(<part>, span)`` must be at most 2.
        """
        file_name = attr.sentence.document.name
        if len(file_name.split('_')) != 2:
            return False
        if attr.get_span()[0] == '-':
            return False
        name = attr.get_span().replace('-', '')
        return (any(char.isdigit() for char in name)
                and any(char.isalpha() for char in name)
                and common_prefix_length_diff(file_name.split('_')[1],
                                              name) <= 2)

    # Raw string: '\-' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning on 3.12+). The pattern is unchanged.
    add_rgx = r'^[A-Z0-9\-]{5,15}$'

    part_file_name_lambda_matcher = LambdaFunctionMatcher(
        func=part_file_name_conditions)
    part_file_name_matcher = Intersect(
        RegexMatchSpan(rgx=add_rgx, longest_match_only=True),
        part_file_name_lambda_matcher)

    part_matcher = Union(part_rgx_matcher, part_dict_matcher,
                         part_file_name_matcher)

    part_ngrams = OmniNgramsPart(parts_by_doc=None, n_max=3)
    attr_ngrams = OmniNgramsTemp(n_max=2)

    def stg_temp_filter(c):
        """Keep same-table candidates only when part and attr are aligned."""
        (part, attr) = c
        if same_table((part, attr)):
            return (is_horz_aligned((part, attr))
                    or is_vert_aligned((part, attr)))
        return True

    candidate_filter = stg_temp_filter

    candidate_extractor = CandidateExtractor(
        Part_Attr, [part_ngrams, attr_ngrams], [part_matcher, attr_matcher],
        candidate_filter=candidate_filter)

    candidate_extractor.apply(train_docs, split=0, parallelism=PARALLEL)

    train_cands = session.query(Part_Attr).filter(Part_Attr.split == 0).all()
    logger.info("Number of candidates: {}".format(len(train_cands)))

    # Dev candidates go to split 1, test candidates to split 2.
    for split, split_docs in enumerate([dev_docs, test_docs], start=1):
        candidate_extractor.apply(split_docs, split=split)
        logger.info("Number of candidates: {}".format(
            session.query(Part_Attr).filter(Part_Attr.split == split).count()))

    featurizer = BatchFeatureAnnotator(Part_Attr)
    F_train = featurizer.apply(
        split=0, replace_key_set=True, parallelism=PARALLEL)
    logger.info(F_train.shape)
    F_dev = featurizer.apply(
        split=1, replace_key_set=False, parallelism=PARALLEL)
    logger.info(F_dev.shape)
    F_test = featurizer.apply(
        split=2, replace_key_set=False, parallelism=PARALLEL)
    logger.info(F_test.shape)

    gold_file = 'tests/e2e/data/hardware_tutorial_gold.csv'
    load_hardware_labels(
        session, Part_Attr, gold_file, ATTRIBUTE, annotator_name='gold')

    # --- Round 1: positive-only labeling functions -------------------------

    def LF_storage_row(c):
        return 1 if 'storage' in get_row_ngrams(c.attr) else 0

    def LF_temperature_row(c):
        return 1 if 'temperature' in get_row_ngrams(c.attr) else 0

    def LF_operating_row(c):
        return 1 if 'operating' in get_row_ngrams(c.attr) else 0

    def LF_tstg_row(c):
        return 1 if overlap(['tstg', 'stg', 'ts'], list(
            get_row_ngrams(c.attr))) else 0

    def LF_to_left(c):
        return 1 if 'to' in get_left_ngrams(c.attr, window=2) else 0

    def LF_negative_number_left(c):
        return 1 if any(
            re.match(r'-\s*\d+', ngram)
            for ngram in get_left_ngrams(c.attr, window=4)) else 0

    stg_temp_lfs = [
        LF_storage_row, LF_operating_row, LF_temperature_row, LF_tstg_row,
        LF_to_left, LF_negative_number_left
    ]

    labeler = BatchLabelAnnotator(Part_Attr, lfs=stg_temp_lfs)
    L_train = labeler.apply(split=0, clear=True, parallelism=PARALLEL)
    logger.info(L_train.shape)

    # Gold label matrices for every split; they are not otherwise used below.
    L_gold_train = load_gold_labels(session, annotator_name='gold', split=0)
    L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)
    L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)

    # The test split does not change between rounds, so fetch its
    # candidates once and reuse them for both evaluations.
    test_candidates = [
        F_test.get_candidate(session, i) for i in range(F_test.shape[0])
    ]

    pickle_file = 'tests/e2e/data/parts_by_doc_dict.pkl'
    # Local test fixture; never unpickle data from an untrusted source.
    with open(pickle_file, 'rb') as f:
        parts_by_doc = pickle.load(f)

    def _train_disc_model(L):
        """Fit the generative model on label matrix ``L``, then train a sparse
        logistic regression on ``F_train`` against its marginals."""
        gen_model = GenerativeModel()
        gen_model.train(
            L,
            epochs=500,
            decay=0.9,
            step_size=0.001 / L.shape[0],
            reg_param=0)
        logger.info("LF Accuracy: {}".format(gen_model.weights.lf_accuracy))
        train_marginals = gen_model.marginals(L)

        disc_model = SparseLogisticRegression()
        disc_model.train(F_train, train_marginals, n_epochs=200, lr=0.001)
        return disc_model

    def _evaluate_f1(disc_model):
        """Log and return the entity-level F1 of ``disc_model`` on split 2."""
        test_score = disc_model.predictions(F_test)
        # np.flatnonzero yields an empty index array when nothing is predicted
        # positive; np.nditer raises on such a zero-sized operand.
        true_pred = [
            test_candidates[i] for i in np.flatnonzero(test_score > 0)
        ]

        (TP, FP, FN) = entity_level_f1(
            true_pred, gold_file, ATTRIBUTE, test_docs,
            parts_by_doc=parts_by_doc)

        tp_len, fp_len, fn_len = len(TP), len(FP), len(FN)
        prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float('nan')
        rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float('nan')
        f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float('nan')

        logger.info("prec: {}".format(prec))
        logger.info("rec: {}".format(rec))
        logger.info("f1: {}".format(f1))
        return f1

    f1 = _evaluate_f1(_train_disc_model(L_train))
    assert 0.4 < f1 < 0.7

    # --- Round 2: add negative labeling functions --------------------------

    def LF_test_condition_aligned(c):
        return -1 if overlap(['test', 'condition'],
                             list(get_aligned_ngrams(c.attr))) else 0

    def LF_collector_aligned(c):
        return -1 if overlap([
            'collector', 'collector-current', 'collector-base',
            'collector-emitter'
        ], list(get_aligned_ngrams(c.attr))) else 0

    def LF_current_aligned(c):
        return -1 if overlap(['current', 'dc', 'ic'],
                             list(get_aligned_ngrams(c.attr))) else 0

    def LF_voltage_row_temp(c):
        return -1 if overlap(['voltage', 'cbo', 'ceo', 'ebo', 'v'],
                             list(get_aligned_ngrams(c.attr))) else 0

    # NOTE(review): identical to LF_voltage_row_temp (both inspect the aligned
    # n-grams of c.attr); the name suggests c.part / its row was intended.
    # Left as-is because the F1 thresholds were tuned with it -- confirm.
    def LF_voltage_row_part(c):
        return -1 if overlap(['voltage', 'cbo', 'ceo', 'ebo', 'v'],
                             list(get_aligned_ngrams(c.attr))) else 0

    def LF_typ_row(c):
        return -1 if overlap(['typ', 'typ.'],
                             list(get_row_ngrams(c.attr))) else 0

    def LF_complement_left_row(c):
        return -1 if (overlap(['complement', 'complementary'],
                              chain.from_iterable([
                                  get_row_ngrams(c.part),
                                  get_left_ngrams(c.part, window=10)
                              ]))) else 0

    def LF_too_many_numbers_row(c):
        num_numbers = list(get_row_ngrams(c.attr,
                                          attrib="ner_tags")).count('number')
        return -1 if num_numbers >= 3 else 0

    def LF_temp_on_high_page_num(c):
        return -1 if c.attr.get_attrib_tokens('page')[0] > 2 else 0

    # NOTE(review): this parses as `not (is_tabular() is None)`, so it votes
    # -1 for every candidate unless is_tabular() returns None. Probably
    # `-1 if not ...is_tabular() else 0` was meant; kept unchanged because
    # the F1 thresholds below depend on the current labels -- confirm.
    def LF_temp_outside_table(c):
        return -1 if not c.attr.sentence.is_tabular() is None else 0

    def LF_not_temp_relevant(c):
        return -1 if not overlap(
            ['storage', 'temperature', 'tstg', 'stg', 'ts'],
            list(get_aligned_ngrams(c.attr))) else 0

    stg_temp_lfs_2 = [
        LF_test_condition_aligned, LF_collector_aligned, LF_current_aligned,
        LF_voltage_row_temp, LF_voltage_row_part, LF_typ_row,
        LF_complement_left_row, LF_too_many_numbers_row,
        LF_temp_on_high_page_num, LF_temp_outside_table, LF_not_temp_relevant
    ]

    labeler = BatchLabelAnnotator(Part_Attr, lfs=stg_temp_lfs_2)
    L_train = labeler.apply(
        split=0,
        clear=False,
        update_keys=True,
        update_values=True,
        parallelism=PARALLEL)

    f1 = _evaluate_f1(_train_disc_model(L_train))
    assert f1 > 0.7
示例#3
0
def test_e2e(caplog):
    """Run an end-to-end test on documents of the hardware domain.

    Parses the 12 HTML/PDF fixtures under ``tests/e2e/data`` and checks the
    per-document sentence, table, figure and caption counts. It then extracts
    (part, attr) candidates, featurizes them, and trains a generative model
    plus a sparse logistic regression twice: first with positive-only
    labeling functions (test-split entity-level F1 must lie in (0.3, 0.7)),
    then after adding negative labeling functions (F1 must exceed 0.7).

    Args:
        caplog: pytest's log-capture fixture; its level is set to INFO.
    """
    caplog.set_level(logging.INFO)
    # spaCy on macOS has issues with parallel parsing.
    # NOTE(review): os.name == "posix" is also true on Linux, so Linux runs
    # (presumably including Travis) use a single process as well, and the
    # else-branch is only reached on non-POSIX systems. sys.platform ==
    # "darwin" would match the comment above -- confirm the intent first.
    if os.name == "posix":
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 12

    session = Meta.init("postgres://localhost:5432/" + DB).Session()

    Part_Attr = candidate_subclass("Part_Attr", ["part", "attr"])

    docs_path = "tests/e2e/data/html/"
    pdf_path = "tests/e2e/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(structural=True,
                           lingual=True,
                           visual=True,
                           pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    num_docs = session.query(Document).count()
    logger.info("Docs: {}".format(num_docs))
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info("Sentences: {}".format(num_sentences))

    docs = session.query(Document).order_by(Document.name).all()
    ld = len(docs)

    # Expected per-document counts, in document-name order. Comparing whole
    # lists makes a failure show every document's count, not just the first
    # mismatching one.
    assert [len(doc.sentences) for doc in docs] == [
        828, 706, 819, 684, 552, 758, 597, 165, 250, 533, 354, 547
    ]
    assert [len(doc.tables) for doc in docs] == [
        9, 9, 14, 11, 11, 10, 10, 2, 7, 10, 6, 9
    ]
    assert [len(doc.figures) for doc in docs] == [
        32, 11, 38, 31, 7, 38, 10, 31, 4, 27, 5, 27
    ]
    assert [len(doc.captions) for doc in docs] == [0] * max_docs

    # Divide into train / dev / test: the first 50% of documents (ordered by
    # name) go to train, the next 25% to dev and the remainder to test.
    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (_, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # Attribute mentions: 150, 155, ..., 195, 200 or 205.
    attr_matcher = RegexMatchSpan(rgx=r"(?:[1][5-9]|20)[05]",
                                  longest_match_only=False)

    ### Transistor Naming Conventions as Regular Expressions ###
    eeca_rgx = r"([ABC][A-Z][WXYZ]?[0-9]{3,5}(?:[A-Z]){0,5}[0-9]?[A-Z]?(?:-[A-Z0-9]{1,7})?(?:[-][A-Z0-9]{1,2})?(?:\/DG)?)"
    jedec_rgx = r"(2N\d{3,4}[A-Z]{0,5}[0-9]?[A-Z]?)"
    jis_rgx = r"(2S[ABCDEFGHJKMQRSTVZ]{1}[\d]{2,4})"
    others_rgx = r"((?:NSVBC|SMBT|MJ|MJE|MPS|MRF|RCA|TIP|ZTX|ZT|ZXT|TIS|TIPL|DTC|MMBT|SMMBT|PZT|FZT|STD|BUV|PBSS|KSC|CXT|FCX|CMPT){1}[\d]{2,4}[A-Z]{0,5}(?:-[A-Z0-9]{0,6})?(?:[-][A-Z0-9]{0,1})?)"

    part_rgx = "|".join([eeca_rgx, jedec_rgx, jis_rgx, others_rgx])
    part_rgx_matcher = RegexMatchSpan(rgx=part_rgx, longest_match_only=True)

    def get_digikey_parts_set(path):
        """Read the Digi-Key part dictionary CSV and return its part numbers.

        Every row must have exactly two columns (part, url); only the part
        column is kept.
        """
        all_parts = set()
        # The csv module documentation asks for newline="" on reader files.
        with open(path, "r", newline="") as csvinput:
            for part, _url in csv.reader(csvinput):
                all_parts.add(part)
        return all_parts

    ### Dictionary of known transistor parts ###
    dict_path = "tests/e2e/data/digikey_part_dictionary.csv"
    part_dict_matcher = DictionaryMatch(d=get_digikey_parts_set(dict_path))

    def common_prefix_length_diff(str1, str2):
        """Return how many characters of the shorter string lie at or after
        the first mismatch (0 when one string is a prefix of the other)."""
        for i in range(min(len(str1), len(str2))):
            if str1[i] != str2[i]:
                return min(len(str1), len(str2)) - i
        return 0

    def part_file_name_conditions(attr):
        """Accept spans that resemble the part number in the file name.

        The document name must look like ``<prefix>_<part>``; the span must
        not start with "-", must contain a digit and a letter (ignoring "-"),
        and ``common_prefix_length_diff(<part>, span)`` must be at most 2.
        """
        file_name = attr.sentence.document.name
        if len(file_name.split("_")) != 2:
            return False
        if attr.get_span()[0] == "-":
            return False
        name = attr.get_span().replace("-", "")
        return (any(char.isdigit() for char in name)
                and any(char.isalpha() for char in name) and
                common_prefix_length_diff(file_name.split("_")[1], name) <= 2)

    # Raw string: "\-" in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning on 3.12+). The pattern is unchanged.
    add_rgx = r"^[A-Z0-9\-]{5,15}$"

    part_file_name_lambda_matcher = LambdaFunctionMatcher(
        func=part_file_name_conditions)
    part_file_name_matcher = Intersect(
        RegexMatchSpan(rgx=add_rgx, longest_match_only=True),
        part_file_name_lambda_matcher,
    )

    part_matcher = Union(part_rgx_matcher, part_dict_matcher,
                         part_file_name_matcher)

    part_ngrams = OmniNgramsPart(parts_by_doc=None, n_max=3)
    attr_ngrams = OmniNgramsTemp(n_max=2)

    def stg_temp_filter(c):
        """Keep same-table candidates only when part and attr are aligned."""
        (part, attr) = c
        if same_table((part, attr)):
            return is_horz_aligned((part, attr)) or is_vert_aligned(
                (part, attr))
        return True

    candidate_filter = stg_temp_filter

    candidate_extractor = CandidateExtractor(
        Part_Attr,
        [part_ngrams, attr_ngrams],
        [part_matcher, attr_matcher],
        candidate_filter=candidate_filter,
    )

    candidate_extractor.apply(train_docs, split=0, parallelism=PARALLEL)

    train_cands = session.query(Part_Attr).filter(Part_Attr.split == 0).all()
    logger.info("Number of candidates: {}".format(len(train_cands)))

    # Dev candidates go to split 1, test candidates to split 2.
    for split, split_docs in enumerate([dev_docs, test_docs], start=1):
        candidate_extractor.apply(split_docs, split=split)
        logger.info("Number of candidates: {}".format(
            session.query(Part_Attr).filter(Part_Attr.split == split).count()))

    featurizer = BatchFeatureAnnotator(Part_Attr)
    F_train = featurizer.apply(split=0,
                               replace_key_set=True,
                               parallelism=PARALLEL)
    logger.info(F_train.shape)
    F_dev = featurizer.apply(split=1,
                             replace_key_set=False,
                             parallelism=PARALLEL)
    logger.info(F_dev.shape)
    F_test = featurizer.apply(split=2,
                              replace_key_set=False,
                              parallelism=PARALLEL)
    logger.info(F_test.shape)

    gold_file = "tests/e2e/data/hardware_tutorial_gold.csv"
    load_hardware_labels(session,
                         Part_Attr,
                         gold_file,
                         ATTRIBUTE,
                         annotator_name="gold")

    # --- Round 1: positive-only labeling functions -------------------------

    def LF_storage_row(c):
        return 1 if "storage" in get_row_ngrams(c.attr) else 0

    def LF_temperature_row(c):
        return 1 if "temperature" in get_row_ngrams(c.attr) else 0

    def LF_operating_row(c):
        return 1 if "operating" in get_row_ngrams(c.attr) else 0

    def LF_tstg_row(c):
        return 1 if overlap(["tstg", "stg", "ts"], list(get_row_ngrams(
            c.attr))) else 0

    def LF_to_left(c):
        return 1 if "to" in get_left_ngrams(c.attr, window=2) else 0

    def LF_negative_number_left(c):
        return (1 if any(
            re.match(r"-\s*\d+", ngram)
            for ngram in get_left_ngrams(c.attr, window=4)) else 0)

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = BatchLabelAnnotator(Part_Attr, lfs=stg_temp_lfs)
    L_train = labeler.apply(split=0, clear=True, parallelism=PARALLEL)
    logger.info(L_train.shape)

    # Gold label matrices for every split; they are not otherwise used below.
    L_gold_train = load_gold_labels(session, annotator_name="gold", split=0)
    L_gold_dev = load_gold_labels(session, annotator_name="gold", split=1)
    L_gold_test = load_gold_labels(session, annotator_name="gold", split=2)

    # The test split does not change between rounds, so fetch its
    # candidates once and reuse them for both evaluations.
    test_candidates = [
        F_test.get_candidate(session, i) for i in range(F_test.shape[0])
    ]

    pickle_file = "tests/e2e/data/parts_by_doc_dict.pkl"
    # Local test fixture; never unpickle data from an untrusted source.
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    def _train_disc_model(L):
        """Fit the generative model on label matrix ``L``, then train a sparse
        logistic regression on ``F_train`` against its marginals."""
        gen_model = GenerativeModel()
        gen_model.train(L,
                        epochs=500,
                        decay=0.9,
                        step_size=0.001 / L.shape[0],
                        reg_param=0)
        logger.info("LF Accuracy: {}".format(gen_model.weights.lf_accuracy))
        train_marginals = gen_model.marginals(L)

        disc_model = SparseLogisticRegression()
        disc_model.train(F_train, train_marginals, n_epochs=200, lr=0.001)
        return disc_model

    def _evaluate_f1(disc_model):
        """Log and return the entity-level F1 of ``disc_model`` on split 2."""
        test_score = disc_model.predictions(F_test)
        # np.flatnonzero yields an empty index array when nothing is predicted
        # positive; np.nditer raises on such a zero-sized operand.
        true_pred = [
            test_candidates[i] for i in np.flatnonzero(test_score > 0)
        ]

        (TP, FP, FN) = entity_level_f1(true_pred,
                                       gold_file,
                                       ATTRIBUTE,
                                       test_docs,
                                       parts_by_doc=parts_by_doc)

        tp_len, fp_len, fn_len = len(TP), len(FP), len(FN)
        prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
        rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
        f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

        logger.info("prec: {}".format(prec))
        logger.info("rec: {}".format(rec))
        logger.info("f1: {}".format(f1))
        return f1

    f1 = _evaluate_f1(_train_disc_model(L_train))
    assert 0.3 < f1 < 0.7

    # --- Round 2: add negative labeling functions --------------------------

    def LF_test_condition_aligned(c):
        return (-1 if overlap(["test", "condition"],
                              list(get_aligned_ngrams(c.attr))) else 0)

    def LF_collector_aligned(c):
        return (-1 if overlap(
            [
                "collector",
                "collector-current",
                "collector-base",
                "collector-emitter",
            ],
            list(get_aligned_ngrams(c.attr)),
        ) else 0)

    def LF_current_aligned(c):
        return (-1 if overlap(["current", "dc", "ic"],
                              list(get_aligned_ngrams(c.attr))) else 0)

    def LF_voltage_row_temp(c):
        return (-1 if overlap(["voltage", "cbo", "ceo", "ebo", "v"],
                              list(get_aligned_ngrams(c.attr))) else 0)

    # NOTE(review): identical to LF_voltage_row_temp (both inspect the aligned
    # n-grams of c.attr); the name suggests c.part / its row was intended.
    # Left as-is because the F1 thresholds were tuned with it -- confirm.
    def LF_voltage_row_part(c):
        return (-1 if overlap(["voltage", "cbo", "ceo", "ebo", "v"],
                              list(get_aligned_ngrams(c.attr))) else 0)

    def LF_typ_row(c):
        return -1 if overlap(["typ", "typ."], list(get_row_ngrams(
            c.attr))) else 0

    def LF_complement_left_row(c):
        return (-1 if (overlap(
            ["complement", "complementary"],
            chain.from_iterable(
                [get_row_ngrams(c.part),
                 get_left_ngrams(c.part, window=10)]),
        )) else 0)

    def LF_too_many_numbers_row(c):
        num_numbers = list(get_row_ngrams(c.attr,
                                          attrib="ner_tags")).count("number")
        return -1 if num_numbers >= 3 else 0

    def LF_temp_on_high_page_num(c):
        return -1 if c.attr.get_attrib_tokens("page")[0] > 2 else 0

    # NOTE(review): this parses as `not (is_tabular() is None)`, so it votes
    # -1 for every candidate unless is_tabular() returns None. Probably
    # `-1 if not ...is_tabular() else 0` was meant; kept unchanged because
    # the F1 thresholds below depend on the current labels -- confirm.
    def LF_temp_outside_table(c):
        return -1 if not c.attr.sentence.is_tabular() is None else 0

    def LF_not_temp_relevant(c):
        return (-1 if not overlap(
            ["storage", "temperature", "tstg", "stg", "ts"],
            list(get_aligned_ngrams(c.attr)),
        ) else 0)

    stg_temp_lfs_2 = [
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]

    labeler = BatchLabelAnnotator(Part_Attr, lfs=stg_temp_lfs_2)
    L_train = labeler.apply(split=0,
                            clear=False,
                            update_keys=True,
                            update_values=True,
                            parallelism=PARALLEL)

    f1 = _evaluate_f1(_train_disc_model(L_train))
    assert f1 > 0.7
示例#4
0
]
test_score = disc_model.predictions(F_test)
# Indices of test candidates predicted positive. np.flatnonzero returns an
# empty array when there are none; np.nditer raises on a zero-sized operand.
true_pred = [test_candidates[i] for i in np.flatnonzero(test_score > 0)]
train_score = disc_model.predictions(F_train)

# Load the gold labels for the Org_Fig candidates.
from tutorials.organic_synthesis_figures.organic_utils import load_organic_labels
from fonduer import load_gold_labels

load_organic_labels(session,
                    Org_Fig,
                    gold_file,
                    ATTRIBUTE,
                    annotator_name='gold')

# Per-labeling-function statistics against the gold labels of split 0 ...
L_gold_train = load_gold_labels(session, annotator_name="gold", split=0)
print(L_train.lf_stats(L_gold_train))

# ... and of split 1, which this fragment evaluates as the test set.
L_gold_test = load_gold_labels(session, annotator_name="gold", split=1)
print(L_test.lf_stats(L_gold_test))

prec, rec, f1 = gen_model.score(L_test, L_gold_test)

# NOTE(review): imported via the bare `organic_utils` module name, whereas
# load_organic_labels above comes from
# `tutorials.organic_synthesis_figures.organic_utils`; this only resolves when
# that directory is on sys.path -- confirm which import path is intended.
from organic_utils import entity_level_f1

(TP, FP, FN) = entity_level_f1(true_pred, gold_file, ATTRIBUTE, test_docs)

from matplotlib import pyplot as plt


def plot_tp_entity(e):
示例#5
0
# In[24]:

# Notebook cell: the candidate returned here is only displayed when run
# interactively; as a script the value is discarded.
L_train.get_candidate(session, 0)

# We can also view statistics about the resulting label matrix.
# * **Coverage** is the fraction of candidates for which the labeling function emits a non-zero label.
# * **Overlap** is the fraction of candidates for which the labeling function emits a non-zero label and at least one other labeling function also emits a non-zero label.
# * **Conflict** is the fraction of candidates for which the labeling function emits a non-zero label and at least one other labeling function emits a conflicting non-zero label.
#
# In addition, because we have already loaded the gold labels, we can view the empirical accuracy of these labeling functions compared to our gold labels:

# In[25]:

from fonduer import load_gold_labels
# Gold label matrix for the training split (split 0).
L_gold_train = load_gold_labels(session, annotator_name='gold', split=0)
# Notebook display of the per-LF statistics; not printed when run as a script.
L_train.lf_stats(L_gold_train)

# ### Fitting the Generative Model
#
# Now we'll train a model of the LFs to estimate their accuracies. Once the model is trained, we can combine the outputs of the LFs into a single, noise-aware training label set for our extractor. Intuitively, we'll model the LFs by observing how they overlap and conflict with each other.

# In[26]:

from fonduer import GenerativeModel

gen_model = GenerativeModel()
gen_model.train(L_train,
                epochs=500,
                decay=0.9,
                step_size=0.001 / L_train.shape[0],