Ejemplo n.º 1
0
def label_cand():
    """Apply the organic-figure labeling functions to split 0.

    Builds a ``BatchLabelAnnotator`` over ``Org_Fig`` with ``org_fig_lfs``,
    applies it to split 0 with ``clear=True``, prints the shape of the
    resulting label matrix and returns that matrix.
    """
    annotator = BatchLabelAnnotator(Org_Fig, lfs=org_fig_lfs)
    label_matrix = annotator.apply(
        split=0,
        clear=True,
        parallelism=PARALLEL,
    )
    print(label_matrix.shape)
    return label_matrix
Ejemplo n.º 2
0
def train_generative():
    """Train a generative model on the split-0 label matrix.

    Loads the label matrix for split 0 through a ``BatchLabelAnnotator``,
    trains a fresh ``GenerativeModel`` on it and returns the marginals the
    trained model computes for that same matrix.
    """
    model = GenerativeModel()
    annotator = BatchLabelAnnotator(Org_Fig, lfs=org_fig_lfs)
    label_matrix = annotator.load_matrix(split=0)

    # The step size is divided by the number of rows in the label matrix.
    training_options = dict(
        epochs=500,
        decay=0.9,
        step_size=0.001 / label_matrix.shape[0],
        reg_param=0,
    )
    model.train(label_matrix, **training_options)

    return model.marginals(label_matrix)
Ejemplo n.º 3
0
    LF_text_desc_match,
    LF_ocr_text_match,
    LF_text_length_match,
    LF_match_whitelist,
    LF_match_blacklist,
    LF_match_page,
    LF_pos_near,
    LF_organic_compound,
    LF_synthesis_of,
    LF_product_of,
    LF_first_period,
]

from fonduer import BatchLabelAnnotator

# Build the annotator for Org_Fig candidates and load the label matrix of
# split 0 (load_matrix is used here instead of the apply call kept below).
labeler = BatchLabelAnnotator(Org_Fig, lfs=org_fig_lfs)
L_train = labeler.load_matrix(split=0)

# Alternative: apply the labeling functions to split 0 with clear=True.
# L_train = labeler.apply(split=0, clear=True, parallelism=PARALLEL)
print(L_train.shape)

# Look up the candidate at index 0 of the label matrix. The result is not
# stored, so this only shows something in an interactive session.
L_train.get_candidate(session, 0)

from fonduer import GenerativeModel

# Train a generative model on the label matrix; the step size is divided by
# the number of rows of L_train.
gen_model = GenerativeModel()
gen_model.train(L_train,
                epochs=500,
                decay=0.9,
                step_size=0.001 / L_train.shape[0],
                reg_param=0)
Ejemplo n.º 4
0
# Candidate class with a 'product' and a 'figure' field.
Org_Fig = candidate_subclass('Org_Fig', ['product','figure'])

from fonduer import HTMLPreprocessor, OmniParser

# Build the data locations with os.path.join so they are valid whether or not
# FONDUERHOME ends with a path separator; plain string concatenation silently
# produced a broken path when it did not.
docs_path = os.path.join(os.environ['FONDUERHOME'], 'tutorials/organic_synthesis_figures/data/html/')
pdf_path = os.path.join(os.environ['FONDUERHOME'], 'tutorials/organic_synthesis_figures/data/pdf/')



# load gold label
from tutorials.organic_synthesis_figures.organic_utils import load_organic_labels

gold_file = os.path.join(os.environ['FONDUERHOME'], 'tutorials/organic_synthesis_figures/data/organic_gold.csv')
load_organic_labels(session, Org_Fig, gold_file, ATTRIBUTE, annotator_name='gold')

from tutorials.organic_synthesis_figures import organic_lfs


from fonduer import BatchLabelAnnotator
#
# Apply the labeling functions to split 0 with clear=True.
labeler = BatchLabelAnnotator(Org_Fig, lfs = organic_lfs.org_fig_lfs)
L_train = labeler.apply(split=0, clear=True, parallelism=PARALLEL)
# print(L_train.shape)
#
# L_train.get_candidate(session, 0)

# # Applying the Labeling Functions
from fonduer import load_gold_labels
L_gold_train = load_gold_labels(session, annotator_name='gold', split=0)
# Print the statistics: as a bare expression the result was discarded when
# this file is run as a script.
print(L_train.lf_stats(L_gold_train))
Ejemplo n.º 5
0
def test_e2e(caplog):
    """Run an end-to-end test on documents of the hardware domain."""
    caplog.set_level(logging.INFO)
    PARALLEL = 2
    max_docs = 12

    session = SnorkelSession()

    Part_Attr = candidate_subclass('Part_Attr', ['part', 'attr'])

    docs_path = 'tests/e2e/data/html/'
    pdf_path = 'tests/e2e/data/pdf/'

    doc_preprocessor = HTMLPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = OmniParser(
        structural=True, lingual=True, visual=True, pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    num_docs = session.query(Document).count()
    logger.info("Docs: {}".format(num_docs))
    assert num_docs == max_docs

    num_phrases = session.query(Phrase).count()
    logger.info("Phrases: {}".format(num_phrases))
    #  assert num_phrases == 20

    # Divide into test and train
    docs = session.query(Document).order_by(Document.name).all()
    ld = len(docs)

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (doc_name, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    attr_matcher = RegexMatchSpan(
        rgx=r'(?:[1][5-9]|20)[05]', longest_match_only=False)

    ### Transistor Naming Conventions as Regular Expressions ###
    eeca_rgx = r'([ABC][A-Z][WXYZ]?[0-9]{3,5}(?:[A-Z]){0,5}[0-9]?[A-Z]?(?:-[A-Z0-9]{1,7})?(?:[-][A-Z0-9]{1,2})?(?:\/DG)?)'
    jedec_rgx = r'(2N\d{3,4}[A-Z]{0,5}[0-9]?[A-Z]?)'
    jis_rgx = r'(2S[ABCDEFGHJKMQRSTVZ]{1}[\d]{2,4})'
    others_rgx = r'((?:NSVBC|SMBT|MJ|MJE|MPS|MRF|RCA|TIP|ZTX|ZT|ZXT|TIS|TIPL|DTC|MMBT|SMMBT|PZT|FZT|STD|BUV|PBSS|KSC|CXT|FCX|CMPT){1}[\d]{2,4}[A-Z]{0,5}(?:-[A-Z0-9]{0,6})?(?:[-][A-Z0-9]{0,1})?)'

    part_rgx = '|'.join([eeca_rgx, jedec_rgx, jis_rgx, others_rgx])
    part_rgx_matcher = RegexMatchSpan(rgx=part_rgx, longest_match_only=True)

    def get_digikey_parts_set(path):
        """Return the set of part names found in the Digi-Key dictionary CSV.

        Every row of the file at ``path`` must contain exactly two fields, a
        part name followed by a URL; only the part name is kept. A row with
        a different number of fields raises ``ValueError``.
        """
        with open(path, "r") as csvinput:
            return {part for (part, _url) in csv.reader(csvinput)}

    ### Dictionary of known transistor parts ###
    dict_path = 'tests/e2e/data/digikey_part_dictionary.csv'
    part_dict_matcher = DictionaryMatch(d=get_digikey_parts_set(dict_path))

    def common_prefix_length_diff(str1, str2):
        """Return how much of the shorter string lies at or after the first mismatch.

        The two strings are compared position by position over the length of
        the shorter one. At the first differing position ``i`` the result is
        ``min(len(str1), len(str2)) - i``; if no position differs it is 0.
        """
        shorter_len = min(len(str1), len(str2))
        for position, (left, right) in enumerate(zip(str1, str2)):
            if left != right:
                return shorter_len - position
        return 0

    def part_file_name_conditions(attr):
        """Decide whether a span looks like the part named in its file name.

        Returns True only if all of the following hold:
        - the document name splits into exactly two pieces on '_';
        - the span text does not start with '-';
        - the span text, with '-' removed, contains a digit and a letter;
        - ``common_prefix_length_diff`` between the second file-name piece
          and that cleaned text is at most 2.
        """
        name_pieces = attr.sentence.document.name.split('_')
        if len(name_pieces) != 2:
            return False

        span_text = attr.get_span()
        if span_text[0] == '-':
            return False

        name = span_text.replace('-', '')
        if not any(char.isdigit() for char in name):
            return False
        if not any(char.isalpha() for char in name):
            return False
        return common_prefix_length_diff(name_pieces[1], name) <= 2

    # Raw string: "\-" is not a valid escape sequence in a normal string
    # literal and triggers an invalid-escape warning. The pattern text itself
    # is unchanged.
    add_rgx = r'^[A-Z0-9\-]{5,15}$'

    part_file_name_lambda_matcher = LambdaFunctionMatcher(
        func=part_file_name_conditions)
    part_file_name_matcher = Intersect(
        RegexMatchSpan(rgx=add_rgx, longest_match_only=True),
        part_file_name_lambda_matcher)

    part_matcher = Union(part_rgx_matcher, part_dict_matcher,
                         part_file_name_matcher)

    part_ngrams = OmniNgramsPart(parts_by_doc=None, n_max=3)
    attr_ngrams = OmniNgramsTemp(n_max=2)

    def stg_temp_filter(c):
        """Candidate filter for (part, attr) pairs.

        Pairs that are not in the same table are always kept. Pairs in the
        same table are kept only if they are horizontally or vertically
        aligned.
        """
        part, attr = c
        pair = (part, attr)
        if not same_table(pair):
            return True
        return is_horz_aligned(pair) or is_vert_aligned(pair)

    candidate_filter = stg_temp_filter

    candidate_extractor = CandidateExtractor(
        Part_Attr, [part_ngrams, attr_ngrams], [part_matcher, attr_matcher],
        candidate_filter=candidate_filter)

    candidate_extractor.apply(train_docs, split=0, parallelism=PARALLEL)

    train_cands = session.query(Part_Attr).filter(Part_Attr.split == 0).all()
    logger.info("Number of candidates: {}".format(len(train_cands)))

    for i, docs in enumerate([dev_docs, test_docs]):
        candidate_extractor.apply(docs, split=i + 1)
        logger.info("Number of candidates: {}".format(
            session.query(Part_Attr).filter(Part_Attr.split == i + 1).count()))

    featurizer = BatchFeatureAnnotator(Part_Attr)
    F_train = featurizer.apply(
        split=0, replace_key_set=True, parallelism=PARALLEL)
    logger.info(F_train.shape)
    F_dev = featurizer.apply(
        split=1, replace_key_set=False, parallelism=PARALLEL)
    logger.info(F_dev.shape)
    F_test = featurizer.apply(
        split=2, replace_key_set=False, parallelism=PARALLEL)
    logger.info(F_test.shape)

    gold_file = 'tests/e2e/data/hardware_tutorial_gold.csv'
    load_hardware_labels(
        session, Part_Attr, gold_file, ATTRIBUTE, annotator_name='gold')

    def LF_storage_row(c):
        """Vote 1 if 'storage' is among the attr's row n-grams, else 0."""
        if 'storage' in get_row_ngrams(c.attr):
            return 1
        return 0

    def LF_temperature_row(c):
        """Vote 1 if 'temperature' is among the attr's row n-grams, else 0."""
        if 'temperature' in get_row_ngrams(c.attr):
            return 1
        return 0

    def LF_operating_row(c):
        """Vote 1 if 'operating' is among the attr's row n-grams, else 0."""
        if 'operating' in get_row_ngrams(c.attr):
            return 1
        return 0

    def LF_tstg_row(c):
        """Vote 1 if overlap() matches a storage-temperature symbol in the row."""
        row_ngrams = list(get_row_ngrams(c.attr))
        if overlap(['tstg', 'stg', 'ts'], row_ngrams):
            return 1
        return 0

    def LF_to_left(c):
        """Vote 1 if 'to' is among the n-grams left of the attr (window=2)."""
        if 'to' in get_left_ngrams(c.attr, window=2):
            return 1
        return 0

    def LF_negative_number_left(c):
        """Vote 1 if an n-gram left of the attr (window=4) starts like '-<digits>'."""
        matches = [
            re.match(r'-\s*\d+', ngram)
            for ngram in get_left_ngrams(c.attr, window=4)
        ]
        if any(matches):
            return 1
        return 0

    stg_temp_lfs = [
        LF_storage_row, LF_operating_row, LF_temperature_row, LF_tstg_row,
        LF_to_left, LF_negative_number_left
    ]

    labeler = BatchLabelAnnotator(Part_Attr, lfs=stg_temp_lfs)
    L_train = labeler.apply(split=0, clear=True, parallelism=PARALLEL)
    logger.info(L_train.shape)

    L_gold_train = load_gold_labels(session, annotator_name='gold', split=0)

    gen_model = GenerativeModel()
    gen_model.train(
        L_train,
        epochs=500,
        decay=0.9,
        step_size=0.001 / L_train.shape[0],
        reg_param=0)
    logger.info("LF Accuracy: {}".format(gen_model.weights.lf_accuracy))

    L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)

    train_marginals = gen_model.marginals(L_train)

    disc_model = SparseLogisticRegression()
    disc_model.train(F_train, train_marginals, n_epochs=200, lr=0.001)

    L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)

    test_candidates = [
        F_test.get_candidate(session, i) for i in range(F_test.shape[0])
    ]
    test_score = disc_model.predictions(F_test)
    true_pred = [
        test_candidates[_] for _ in np.nditer(np.where(test_score > 0))
    ]

    pickle_file = 'tests/e2e/data/parts_by_doc_dict.pkl'
    with open(pickle_file, 'rb') as f:
        parts_by_doc = pickle.load(f)

    (TP, FP, FN) = entity_level_f1(
        true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float('nan')
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float('nan')
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float('nan')

    logger.info("prec: {}".format(prec))
    logger.info("rec: {}".format(rec))
    logger.info("f1: {}".format(f1))

    assert f1 < 0.7 and f1 > 0.4

    def LF_test_condition_aligned(c):
        """Vote -1 if overlap() matches 'test'/'condition' in the aligned n-grams."""
        aligned = list(get_aligned_ngrams(c.attr))
        if overlap(['test', 'condition'], aligned):
            return -1
        return 0

    def LF_collector_aligned(c):
        """Vote -1 if overlap() matches a collector-related term in the aligned n-grams."""
        collector_terms = [
            'collector', 'collector-current', 'collector-base',
            'collector-emitter'
        ]
        aligned = list(get_aligned_ngrams(c.attr))
        if overlap(collector_terms, aligned):
            return -1
        return 0

    def LF_current_aligned(c):
        """Vote -1 if overlap() matches a current-related term in the aligned n-grams."""
        aligned = list(get_aligned_ngrams(c.attr))
        if overlap(['current', 'dc', 'ic'], aligned):
            return -1
        return 0

    def LF_voltage_row_temp(c):
        """Vote -1 if overlap() matches a voltage-related term aligned with the attr."""
        aligned = list(get_aligned_ngrams(c.attr))
        if overlap(['voltage', 'cbo', 'ceo', 'ebo', 'v'], aligned):
            return -1
        return 0

    def LF_voltage_row_part(c):
        """Vote -1 if overlap() matches a voltage-related term aligned with the attr.

        NOTE(review): despite its name this inspects ``c.attr`` and is
        identical to ``LF_voltage_row_temp`` -- confirm whether ``c.part``
        was intended.
        """
        aligned = list(get_aligned_ngrams(c.attr))
        if overlap(['voltage', 'cbo', 'ceo', 'ebo', 'v'], aligned):
            return -1
        return 0

    def LF_typ_row(c):
        """Vote -1 if overlap() matches 'typ'/'typ.' in the attr's row n-grams."""
        row_ngrams = list(get_row_ngrams(c.attr))
        if overlap(['typ', 'typ.'], row_ngrams):
            return -1
        return 0

    def LF_complement_left_row(c):
        """Vote -1 if a 'complement' term is in the part's row or left n-grams."""
        nearby = chain(
            get_row_ngrams(c.part),
            get_left_ngrams(c.part, window=10),
        )
        if overlap(['complement', 'complementary'], nearby):
            return -1
        return 0

    def LF_too_many_numbers_row(c):
        """Vote -1 if at least three 'number' NER tags are in the attr's row."""
        row_tags = get_row_ngrams(c.attr, attrib="ner_tags")
        num_numbers = sum(1 for tag in row_tags if tag == 'number')
        if num_numbers >= 3:
            return -1
        return 0

    def LF_temp_on_high_page_num(c):
        """Vote -1 if the attr's first 'page' token is greater than 2."""
        first_page = c.attr.get_attrib_tokens('page')[0]
        if first_page > 2:
            return -1
        return 0

    def LF_temp_outside_table(c):
        """Vote -1 if the attr's sentence is not inside a table, else 0.

        The previous condition, ``not ...is_tabular() is None``, parses as
        ``is_tabular() is not None``. That never tests whether the sentence
        is outside a table, and is always true if ``is_tabular()`` returns a
        boolean. The truthiness of ``is_tabular()`` is now used directly.

        NOTE(review): this changes the votes of this labeling function, so
        score thresholds asserted further down may need re-validation.
        """
        return -1 if not c.attr.sentence.is_tabular() else 0

    def LF_not_temp_relevant(c):
        """Vote -1 if overlap() matches no storage-temperature term aligned with the attr."""
        aligned = list(get_aligned_ngrams(c.attr))
        if overlap(['storage', 'temperature', 'tstg', 'stg', 'ts'], aligned):
            return 0
        return -1

    stg_temp_lfs_2 = [
        LF_test_condition_aligned, LF_collector_aligned, LF_current_aligned,
        LF_voltage_row_temp, LF_voltage_row_part, LF_typ_row,
        LF_complement_left_row, LF_too_many_numbers_row,
        LF_temp_on_high_page_num, LF_temp_outside_table, LF_not_temp_relevant
    ]

    labeler = BatchLabelAnnotator(Part_Attr, lfs=stg_temp_lfs_2)
    L_train = labeler.apply(
        split=0,
        clear=False,
        update_keys=True,
        update_values=True,
        parallelism=PARALLEL)
    gen_model = GenerativeModel()
    gen_model.train(
        L_train,
        epochs=500,
        decay=0.9,
        step_size=0.001 / L_train.shape[0],
        reg_param=0)
    train_marginals = gen_model.marginals(L_train)

    disc_model = SparseLogisticRegression()
    disc_model.train(F_train, train_marginals, n_epochs=200, lr=0.001)

    test_candidates = [
        F_test.get_candidate(session, i) for i in range(F_test.shape[0])
    ]
    test_score = disc_model.predictions(F_test)
    true_pred = [
        test_candidates[_] for _ in np.nditer(np.where(test_score > 0))
    ]

    (TP, FP, FN) = entity_level_f1(
        true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float('nan')
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float('nan')
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float('nan')

    logger.info("prec: {}".format(prec))
    logger.info("rec: {}".format(rec))
    logger.info("f1: {}".format(f1))

    assert f1 > 0.7
Ejemplo n.º 6
0
from fonduer import candidate_subclass

# Candidate class with a 'product' and a 'figure' field.
Org_Fig = candidate_subclass('Org_Fig', ['product', 'figure'])

from fonduer import HTMLPreprocessor, OmniParser

# Build the data locations with os.path.join so they are valid whether or not
# FONDUERHOME ends with a path separator; plain string concatenation silently
# produced a broken path when it did not.
docs_path = os.path.join(
    os.environ['FONDUERHOME'],
    'tutorials/organic_synthesis_figures/data/html/')
pdf_path = os.path.join(
    os.environ['FONDUERHOME'],
    'tutorials/organic_synthesis_figures/data/pdf/')

from fonduer import BatchLabelAnnotator
from tutorials.organic_synthesis_figures import organic_lfs

# Load the label matrix of split 0 and report its shape.
labeler = BatchLabelAnnotator(Org_Fig, lfs=organic_lfs.org_fig_lfs)
L_train = labeler.load_matrix(split=0)
print(L_train.shape)

# Look up the candidate at index 0; the result is not stored.
L_train.get_candidate(session, 0)

# Applying the Labeling Functions
from fonduer import load_gold_labels
L_gold_train = load_gold_labels(session, annotator_name='gold', split=0)
print(L_train.lf_stats(L_gold_train))

L_gold_test = load_gold_labels(session, annotator_name='gold', split=1)
'''

    Training
Ejemplo n.º 7
0
# Load the feature matrices for split 0 and split 1. With image features
# enabled, whatever load_matrix_and_image_features returns is stacked
# horizontally and converted to a dense array (presumably a sparse feature
# matrix plus dense image features -- TODO confirm).
if with_image_feats:
    print('Merging image features')
    F_train = sparse.hstack(featurizer.load_matrix_and_image_features(
        split=0)).toarray()  # concatenate dense with sparse matrix
    F_test = sparse.hstack(featurizer.load_matrix_and_image_features(
        split=1)).toarray()  # concatenate dense with sparse matrix
else:
    F_train = featurizer.load_matrix(split=0)
    F_test = featurizer.load_matrix(split=1)

from fonduer import BatchLabelAnnotator

# NOTE(review): star import; org_fig_lfs used below presumably comes from
# this module -- confirm and consider importing it explicitly.
from organic_lfs import *

labeler = BatchLabelAnnotator(Org_Fig, lfs=org_fig_lfs)

# Either re-apply the labeling functions to split 0 (with clear=True) or load
# the label matrix for that split.
if relabel:
    L_train = labeler.apply(split=0, clear=True, parallelism=PARALLEL)
else:
    L_train = labeler.load_matrix(split=0)

print(L_train.shape)

# Look up the candidate at index 0; the result is not stored.
L_train.get_candidate(session, 0)

# Label split 1 via apply_existing.
L_test = labeler.apply_existing(split=1)

from fonduer import GenerativeModel

gen_model = GenerativeModel()
Ejemplo n.º 8
0
def test_e2e(caplog):
    """Run an end-to-end test on documents of the hardware domain."""
    caplog.set_level(logging.INFO)
    # SpaCy has issues with parallel parsing on macOS, so use a single process
    # there. os.name is "posix" on Linux as well as macOS, which forced Linux
    # runs onto one process too, so macOS is detected via sys.platform.
    import sys

    PARALLEL = 1 if sys.platform == "darwin" else 2  # Travis only gives 2 cores

    max_docs = 12

    session = Meta.init("postgres://localhost:5432/" + DB).Session()

    Part_Attr = candidate_subclass("Part_Attr", ["part", "attr"])

    docs_path = "tests/e2e/data/html/"
    pdf_path = "tests/e2e/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(structural=True,
                           lingual=True,
                           visual=True,
                           pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    num_docs = session.query(Document).count()
    logger.info("Docs: {}".format(num_docs))
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info("Sentences: {}".format(num_sentences))

    # Divide into test and train
    docs = session.query(Document).order_by(Document.name).all()
    ld = len(docs)
    assert len(docs[0].sentences) == 828
    assert len(docs[1].sentences) == 706
    assert len(docs[2].sentences) == 819
    assert len(docs[3].sentences) == 684
    assert len(docs[4].sentences) == 552
    assert len(docs[5].sentences) == 758
    assert len(docs[6].sentences) == 597
    assert len(docs[7].sentences) == 165
    assert len(docs[8].sentences) == 250
    assert len(docs[9].sentences) == 533
    assert len(docs[10].sentences) == 354
    assert len(docs[11].sentences) == 547

    # Check table numbers
    assert len(docs[0].tables) == 9
    assert len(docs[1].tables) == 9
    assert len(docs[2].tables) == 14
    assert len(docs[3].tables) == 11
    assert len(docs[4].tables) == 11
    assert len(docs[5].tables) == 10
    assert len(docs[6].tables) == 10
    assert len(docs[7].tables) == 2
    assert len(docs[8].tables) == 7
    assert len(docs[9].tables) == 10
    assert len(docs[10].tables) == 6
    assert len(docs[11].tables) == 9

    # Check figure numbers
    assert len(docs[0].figures) == 32
    assert len(docs[1].figures) == 11
    assert len(docs[2].figures) == 38
    assert len(docs[3].figures) == 31
    assert len(docs[4].figures) == 7
    assert len(docs[5].figures) == 38
    assert len(docs[6].figures) == 10
    assert len(docs[7].figures) == 31
    assert len(docs[8].figures) == 4
    assert len(docs[9].figures) == 27
    assert len(docs[10].figures) == 5
    assert len(docs[11].figures) == 27

    # Check caption numbers
    assert len(docs[0].captions) == 0
    assert len(docs[1].captions) == 0
    assert len(docs[2].captions) == 0
    assert len(docs[3].captions) == 0
    assert len(docs[4].captions) == 0
    assert len(docs[5].captions) == 0
    assert len(docs[6].captions) == 0
    assert len(docs[7].captions) == 0
    assert len(docs[8].captions) == 0
    assert len(docs[9].captions) == 0
    assert len(docs[10].captions) == 0
    assert len(docs[11].captions) == 0

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (doc_name, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    attr_matcher = RegexMatchSpan(rgx=r"(?:[1][5-9]|20)[05]",
                                  longest_match_only=False)

    ### Transistor Naming Conventions as Regular Expressions ###
    eeca_rgx = r"([ABC][A-Z][WXYZ]?[0-9]{3,5}(?:[A-Z]){0,5}[0-9]?[A-Z]?(?:-[A-Z0-9]{1,7})?(?:[-][A-Z0-9]{1,2})?(?:\/DG)?)"
    jedec_rgx = r"(2N\d{3,4}[A-Z]{0,5}[0-9]?[A-Z]?)"
    jis_rgx = r"(2S[ABCDEFGHJKMQRSTVZ]{1}[\d]{2,4})"
    others_rgx = r"((?:NSVBC|SMBT|MJ|MJE|MPS|MRF|RCA|TIP|ZTX|ZT|ZXT|TIS|TIPL|DTC|MMBT|SMMBT|PZT|FZT|STD|BUV|PBSS|KSC|CXT|FCX|CMPT){1}[\d]{2,4}[A-Z]{0,5}(?:-[A-Z0-9]{0,6})?(?:[-][A-Z0-9]{0,1})?)"

    part_rgx = "|".join([eeca_rgx, jedec_rgx, jis_rgx, others_rgx])
    part_rgx_matcher = RegexMatchSpan(rgx=part_rgx, longest_match_only=True)

    def get_digikey_parts_set(path):
        """Collect the part names listed in the Digi-Key dictionary CSV.

        Each row of the file at ``path`` is unpacked into a part name and a
        URL (so a row with any other number of fields raises ``ValueError``);
        the part names are returned as a set.
        """
        with open(path, "r") as csvinput:
            rows = csv.reader(csvinput)
            return {part for (part, _url) in rows}

    ### Dictionary of known transistor parts ###
    dict_path = "tests/e2e/data/digikey_part_dictionary.csv"
    part_dict_matcher = DictionaryMatch(d=get_digikey_parts_set(dict_path))

    def common_prefix_length_diff(str1, str2):
        """Return ``min(len(str1), len(str2)) - i`` for the first mismatch index ``i``.

        Only positions within the shorter string are compared. If the strings
        agree on all of them, the result is 0.
        """
        first_mismatch = next(
            (idx for idx, (a, b) in enumerate(zip(str1, str2)) if a != b),
            None,
        )
        if first_mismatch is None:
            return 0
        return min(len(str1), len(str2)) - first_mismatch

    def part_file_name_conditions(attr):
        """Check whether a span resembles the part named in its document name.

        The document name must split into exactly two pieces on "_" and the
        span must not start with "-". With "-" removed, the span must contain
        a digit and a letter, and ``common_prefix_length_diff`` against the
        second piece of the document name must be at most 2.
        """
        pieces = attr.sentence.document.name.split("_")
        if len(pieces) != 2:
            return False

        span_text = attr.get_span()
        if span_text[0] == "-":
            return False

        name = span_text.replace("-", "")
        has_digit = any(char.isdigit() for char in name)
        if not has_digit:
            return False
        has_letter = any(char.isalpha() for char in name)
        if not has_letter:
            return False
        return common_prefix_length_diff(pieces[1], name) <= 2

    # Raw string: "\-" is not a valid escape sequence in a normal string
    # literal and triggers an invalid-escape warning. The pattern text itself
    # is unchanged.
    add_rgx = r"^[A-Z0-9\-]{5,15}$"

    part_file_name_lambda_matcher = LambdaFunctionMatcher(
        func=part_file_name_conditions)
    part_file_name_matcher = Intersect(
        RegexMatchSpan(rgx=add_rgx, longest_match_only=True),
        part_file_name_lambda_matcher,
    )

    part_matcher = Union(part_rgx_matcher, part_dict_matcher,
                         part_file_name_matcher)

    part_ngrams = OmniNgramsPart(parts_by_doc=None, n_max=3)
    attr_ngrams = OmniNgramsTemp(n_max=2)

    def stg_temp_filter(c):
        """Keep a (part, attr) pair unless it shares a table without being aligned.

        Returns True for pairs that are not in the same table. For pairs in
        the same table, returns whether they are horizontally or vertically
        aligned.
        """
        part, attr = c
        pair = (part, attr)
        if same_table(pair):
            return is_horz_aligned(pair) or is_vert_aligned(pair)
        return True

    candidate_filter = stg_temp_filter

    candidate_extractor = CandidateExtractor(
        Part_Attr,
        [part_ngrams, attr_ngrams],
        [part_matcher, attr_matcher],
        candidate_filter=candidate_filter,
    )

    candidate_extractor.apply(train_docs, split=0, parallelism=PARALLEL)

    train_cands = session.query(Part_Attr).filter(Part_Attr.split == 0).all()
    logger.info("Number of candidates: {}".format(len(train_cands)))

    for i, docs in enumerate([dev_docs, test_docs]):
        candidate_extractor.apply(docs, split=i + 1)
        logger.info("Number of candidates: {}".format(
            session.query(Part_Attr).filter(Part_Attr.split == i + 1).count()))

    featurizer = BatchFeatureAnnotator(Part_Attr)
    F_train = featurizer.apply(split=0,
                               replace_key_set=True,
                               parallelism=PARALLEL)
    logger.info(F_train.shape)
    F_dev = featurizer.apply(split=1,
                             replace_key_set=False,
                             parallelism=PARALLEL)
    logger.info(F_dev.shape)
    F_test = featurizer.apply(split=2,
                              replace_key_set=False,
                              parallelism=PARALLEL)
    logger.info(F_test.shape)

    gold_file = "tests/e2e/data/hardware_tutorial_gold.csv"
    load_hardware_labels(session,
                         Part_Attr,
                         gold_file,
                         ATTRIBUTE,
                         annotator_name="gold")

    def LF_storage_row(c):
        """Vote 1 if "storage" is among the attr's row n-grams, else 0."""
        return int("storage" in get_row_ngrams(c.attr))

    def LF_temperature_row(c):
        """Vote 1 if "temperature" is among the attr's row n-grams, else 0."""
        return int("temperature" in get_row_ngrams(c.attr))

    def LF_operating_row(c):
        """Vote 1 if "operating" is among the attr's row n-grams, else 0."""
        return int("operating" in get_row_ngrams(c.attr))

    def LF_tstg_row(c):
        """Vote 1 if overlap() matches a storage-temperature symbol in the row."""
        row_ngrams = list(get_row_ngrams(c.attr))
        return int(bool(overlap(["tstg", "stg", "ts"], row_ngrams)))

    def LF_to_left(c):
        """Vote 1 if "to" is among the n-grams left of the attr (window=2)."""
        return int("to" in get_left_ngrams(c.attr, window=2))

    def LF_negative_number_left(c):
        """Vote 1 if an n-gram left of the attr (window=4) starts like "-<digits>"."""
        left_matches = [
            re.match(r"-\s*\d+", ngram)
            for ngram in get_left_ngrams(c.attr, window=4)
        ]
        return int(any(left_matches))

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = BatchLabelAnnotator(Part_Attr, lfs=stg_temp_lfs)
    L_train = labeler.apply(split=0, clear=True, parallelism=PARALLEL)
    logger.info(L_train.shape)

    L_gold_train = load_gold_labels(session, annotator_name="gold", split=0)

    gen_model = GenerativeModel()
    gen_model.train(L_train,
                    epochs=500,
                    decay=0.9,
                    step_size=0.001 / L_train.shape[0],
                    reg_param=0)
    logger.info("LF Accuracy: {}".format(gen_model.weights.lf_accuracy))

    L_gold_dev = load_gold_labels(session, annotator_name="gold", split=1)

    train_marginals = gen_model.marginals(L_train)

    disc_model = SparseLogisticRegression()
    disc_model.train(F_train, train_marginals, n_epochs=200, lr=0.001)

    L_gold_test = load_gold_labels(session, annotator_name="gold", split=2)

    test_candidates = [
        F_test.get_candidate(session, i) for i in range(F_test.shape[0])
    ]
    test_score = disc_model.predictions(F_test)
    true_pred = [
        test_candidates[_] for _ in np.nditer(np.where(test_score > 0))
    ]

    pickle_file = "tests/e2e/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info("prec: {}".format(prec))
    logger.info("rec: {}".format(rec))
    logger.info("f1: {}".format(f1))

    assert f1 < 0.7 and f1 > 0.3

    def LF_test_condition_aligned(c):
        """Vote -1 if overlap() matches "test"/"condition" in the aligned n-grams."""
        aligned = list(get_aligned_ngrams(c.attr))
        matched = overlap(["test", "condition"], aligned)
        return -1 if matched else 0

    def LF_collector_aligned(c):
        """Vote -1 if overlap() matches a collector-related term in the aligned n-grams."""
        collector_terms = [
            "collector",
            "collector-current",
            "collector-base",
            "collector-emitter",
        ]
        aligned = list(get_aligned_ngrams(c.attr))
        matched = overlap(collector_terms, aligned)
        return -1 if matched else 0

    def LF_current_aligned(c):
        """Vote -1 if overlap() matches a current-related term in the aligned n-grams."""
        aligned = list(get_aligned_ngrams(c.attr))
        matched = overlap(["current", "dc", "ic"], aligned)
        return -1 if matched else 0

    def LF_voltage_row_temp(c):
        """Vote -1 if overlap() matches a voltage-related term aligned with the attr."""
        aligned = list(get_aligned_ngrams(c.attr))
        matched = overlap(["voltage", "cbo", "ceo", "ebo", "v"], aligned)
        return -1 if matched else 0

    def LF_voltage_row_part(c):
        """Vote -1 if overlap() matches a voltage-related term aligned with the attr.

        NOTE(review): despite its name this inspects ``c.attr`` and is
        identical to ``LF_voltage_row_temp`` -- confirm whether ``c.part``
        was intended.
        """
        aligned = list(get_aligned_ngrams(c.attr))
        matched = overlap(["voltage", "cbo", "ceo", "ebo", "v"], aligned)
        return -1 if matched else 0

    def LF_typ_row(c):
        """Vote -1 if overlap() matches "typ"/"typ." in the attr's row n-grams."""
        row_ngrams = list(get_row_ngrams(c.attr))
        matched = overlap(["typ", "typ."], row_ngrams)
        return -1 if matched else 0

    def LF_complement_left_row(c):
        """Vote -1 if a "complement" term is in the part's row or left n-grams."""
        nearby = chain(
            get_row_ngrams(c.part),
            get_left_ngrams(c.part, window=10),
        )
        matched = overlap(["complement", "complementary"], nearby)
        return -1 if matched else 0

    def LF_too_many_numbers_row(c):
        """Vote -1 if at least three "number" NER tags are in the attr's row."""
        row_tags = get_row_ngrams(c.attr, attrib="ner_tags")
        num_numbers = sum(1 for tag in row_tags if tag == "number")
        return -1 if num_numbers >= 3 else 0

    def LF_temp_on_high_page_num(c):
        """Vote -1 if the attr's first "page" token is greater than 2."""
        first_page = c.attr.get_attrib_tokens("page")[0]
        return -1 if first_page > 2 else 0

    def LF_temp_outside_table(c):
        """Vote -1 if the attr's sentence is not inside a table, else 0.

        The previous condition, ``not ...is_tabular() is None``, parses as
        ``is_tabular() is not None``. That never tests whether the sentence
        is outside a table, and is always true if ``is_tabular()`` returns a
        boolean. The truthiness of ``is_tabular()`` is now used directly.

        NOTE(review): this changes the votes of this labeling function, so
        score thresholds asserted further down may need re-validation.
        """
        return -1 if not c.attr.sentence.is_tabular() else 0

    def LF_not_temp_relevant(c):
        """Vote -1 if overlap() matches no storage-temperature term aligned with the attr."""
        relevant_terms = ["storage", "temperature", "tstg", "stg", "ts"]
        aligned = list(get_aligned_ngrams(c.attr))
        if overlap(relevant_terms, aligned):
            return 0
        return -1

    stg_temp_lfs_2 = [
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]

    labeler = BatchLabelAnnotator(Part_Attr, lfs=stg_temp_lfs_2)
    L_train = labeler.apply(split=0,
                            clear=False,
                            update_keys=True,
                            update_values=True,
                            parallelism=PARALLEL)
    gen_model = GenerativeModel()
    gen_model.train(L_train,
                    epochs=500,
                    decay=0.9,
                    step_size=0.001 / L_train.shape[0],
                    reg_param=0)
    train_marginals = gen_model.marginals(L_train)

    disc_model = SparseLogisticRegression()
    disc_model.train(F_train, train_marginals, n_epochs=200, lr=0.001)

    test_candidates = [
        F_test.get_candidate(session, i) for i in range(F_test.shape[0])
    ]
    test_score = disc_model.predictions(F_test)
    true_pred = [
        test_candidates[_] for _ in np.nditer(np.where(test_score > 0))
    ]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info("prec: {}".format(prec))
    logger.info("rec: {}".format(rec))
    logger.info("f1: {}".format(f1))

    assert f1 > 0.7
Ejemplo n.º 9
0
# By default, `labeler.apply` will drop the existing table of labeling functions and the label values for each candidate. However, this behavior can be controlled by the following parameters to improve iteration performance and reduce redundant computation:
# - `split` defines which set to operate on (e.g. train, dev, or test)
# - `clear` can be `True` or `False`, and is `True` by default. When set to `False`, the labeling function table is not dropped, and the behavior of `labeler.apply` is defined by the following two parameters.
# - `update_keys` can be `True` or `False`. When `True`, the keys (which are each labeling function) are updated according to the set of labeling functions provided to the function. This should be set to `True` if new labeling functions are added. When `False`, no new LFs are evaluated and the keys of existing LFs remain the same.
# - `update_values` can be `True` or `False`. This defines how to resolve conflicts. When `True`, the values assigned to each candidate are updated to the new values when in conflict. This should be set to `True` if labeling function logic is edited, even though the name of the labeling function remains the same. When `False`, the existing labels assigned to each candidate are used, and newly computed labels are ignored.
# - `parallelism` is the amount of parallelism to use when labeling.
#
# With this in mind, we set `clear=True` when we first apply our labeling functions, and this ensures that the table is created and initialized with proper keys and values.
#
# In future iterations, we would typically set `clear=False, update_keys=True, update_values=True` so that we can simply update the set of LFs and their values without recreating the entire table. We will see how this is used later in the tutorial.

# In[23]:

from fonduer import BatchLabelAnnotator

labeler = BatchLabelAnnotator(Part_Attr, lfs=stg_temp_lfs)
L_train = labeler.apply(split=0, clear=True, parallelism=PARALLEL)
print(L_train.shape)

# Note that the returned matrix is a special subclass of the scipy.sparse.csr_matrix class, with some special features which we demonstrate below:

# In[24]:

L_train.get_candidate(session, 0)

# We can also view statistics about the resulting label matrix.
# * **Coverage** is the fraction of candidates that the labeling function emits a non-zero label for.
# * **Overlap** is the fraction of candidates that the labeling function emits a non-zero label for and that another labeling function emits a non-zero label for.
# * **Conflict** is the fraction of candidates that the labeling function emits a non-zero label for and that another labeling function emits a conflicting non-zero label for.
#
# In addition, because we have already loaded the gold labels, we can view the empirical accuracy of these labeling functions when compared to our gold labels: