예제 #1
0
# Script: configure the Fonduer environment for the "organic_figure" task,
# open a session, declare the Org_Fig candidate schema and build an image
# featurizer. NOTE(review): `os` is used below but its import is not visible
# in this snippet — confirm it is imported earlier in the file.

# Degree of parallelism. NOTE(review): the original comment said "assuming a
# quad-core machine", yet the value is 1 — confirm the intended setting.
PARALLEL = 1
# Task name; also reused below as the database name.
ATTRIBUTE = "organic_figure"

# Machine-specific checkout location (hard-coded absolute path, no trailing
# slash — the path suffixes below therefore start with '/').
os.environ['FONDUERHOME'] = '/Users/liwei/BoxSync/s2016/Dropbox/839_fonduer'
os.environ['FONDUERDBNAME'] = ATTRIBUTE
# Connection string: local PostgreSQL on port 5432, database named after the
# task.
os.environ[
    'SNORKELDB'] = 'postgres://localhost:5432/' + os.environ['FONDUERDBNAME']

# The environment is configured before this import; presumably fonduer reads
# SNORKELDB when imported or when the session is created — TODO confirm.
from fonduer import SnorkelSession

session = SnorkelSession()

from fonduer import candidate_subclass

# Binary candidate type relating an 'organic' mention to a 'figure' mention.
Org_Fig = candidate_subclass('Org_Fig', ['organic', 'figure'])

# Locations of the HTML and PDF renditions of the tutorial documents.
docs_path = os.environ[
    'FONDUERHOME'] + '/tutorials/organic_synthesis_figures/data/html/'
pdf_path = os.environ[
    'FONDUERHOME'] + '/tutorials/organic_synthesis_figures/data/pdf/'

from fonduer import BatchFeatureAnnotator
from fonduer.features.features import get_organic_image_feats
from fonduer.features.read_images import gen_image_features

# Only need to do this once
gen_image_features(docs_path=docs_path)

# Featurizer for Org_Fig candidates using the image feature function.
featurizer = BatchFeatureAnnotator(Org_Fig, f=get_organic_image_feats)
print('merging results')
예제 #2
0
# Script: configure the Fonduer environment for the "organic_figure" task,
# open a session, declare the Org_Fig candidate schema and load gold labels.
# NOTE(review): `os` is used below but its import is not visible in this
# snippet — confirm it is imported earlier in the file.

# os.system('bash ./set_env.sh')
# Degree of parallelism. NOTE(review): the original comment said "assuming a
# quad-core machine", yet the value is 1 — confirm the intended setting.
PARALLEL = 1
# Task name; also reused below as the database name and passed to the
# gold-label loader.
ATTRIBUTE = "organic_figure"

# Machine-specific checkout location (hard-coded absolute path WITH trailing
# slash — the path suffixes below therefore do not start with '/').
os.environ['FONDUERHOME'] = '/home/xiuyuan/private/839/fonduer_new/839_fonduer/'
os.environ['FONDUERDBNAME'] = ATTRIBUTE
# Connection string for a local PostgreSQL server; the user/password part is
# masked in this copy of the file.
os.environ['SNORKELDB'] = 'postgres://*****:*****@localhost:5432/' + os.environ['FONDUERDBNAME']


# The environment is configured before this import; presumably fonduer reads
# SNORKELDB when imported or when the session is created — TODO confirm.
from fonduer import SnorkelSession

session = SnorkelSession()

from fonduer import candidate_subclass

# Binary candidate type relating a 'product' mention to a 'figure' mention.
Org_Fig = candidate_subclass('Org_Fig', ['product','figure'])

# Imported here but not used in the visible part of the script.
from fonduer import HTMLPreprocessor, OmniParser

# Locations of the HTML and PDF renditions of the tutorial documents.
docs_path = os.environ['FONDUERHOME'] + 'tutorials/organic_synthesis_figures/data/html/'
pdf_path = os.environ['FONDUERHOME'] + 'tutorials/organic_synthesis_figures/data/pdf/'



# Load gold labels from the CSV file under the annotator name 'gold'.
from tutorials.organic_synthesis_figures.organic_utils import load_organic_labels

gold_file = os.environ['FONDUERHOME'] + 'tutorials/organic_synthesis_figures/data/organic_gold.csv'
load_organic_labels(session, Org_Fig, gold_file, ATTRIBUTE ,annotator_name='gold')

# Labeling functions for the task live in this module.
from tutorials.organic_synthesis_figures import organic_lfs
예제 #3
0
def test_e2e(caplog):
    """Run an end-to-end test on documents of the hardware domain.

    The test parses ``max_docs`` datasheets, checks the per-document
    sentence, table, figure and caption counts, extracts ``Part_Attr``
    candidates, featurizes them and then trains a generative plus a
    discriminative model twice:

    1. with a small set of positive labeling functions (entity-level F1 on
       the test split must lie strictly between 0.3 and 0.7), and
    2. with an additional set of negative labeling functions (F1 must then
       exceed 0.7).

    The function relies on the module-level names ``DB``, ``ATTRIBUTE`` and
    ``logger`` as well as on the Fonduer imports of the enclosing file.

    :param caplog: pytest log-capture fixture; used only to set the capture
        level to INFO.
    """
    caplog.set_level(logging.INFO)
    # SpaCy on mac has issues with parallel parsing.
    # NOTE(review): os.name is "posix" on Linux as well as on macOS, so this
    # branch is also taken on Linux CI — confirm whether a macOS-only check
    # (sys.platform == "darwin") was intended.
    if os.name == "posix":
        PARALLEL = 1
    else:
        PARALLEL = 2  # Travis only gives 2 cores

    max_docs = 12

    session = Meta.init("postgres://localhost:5432/" + DB).Session()

    Part_Attr = candidate_subclass("Part_Attr", ["part", "attr"])

    docs_path = "tests/e2e/data/html/"
    pdf_path = "tests/e2e/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(structural=True,
                           lingual=True,
                           visual=True,
                           pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    num_docs = session.query(Document).count()
    logger.info("Docs: {}".format(num_docs))
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info("Sentences: {}".format(num_sentences))

    # Fetch the documents in name order; the expected counts below are
    # indexed by that order.
    docs = session.query(Document).order_by(Document.name).all()
    ld = len(docs)
    assert len(docs[0].sentences) == 828
    assert len(docs[1].sentences) == 706
    assert len(docs[2].sentences) == 819
    assert len(docs[3].sentences) == 684
    assert len(docs[4].sentences) == 552
    assert len(docs[5].sentences) == 758
    assert len(docs[6].sentences) == 597
    assert len(docs[7].sentences) == 165
    assert len(docs[8].sentences) == 250
    assert len(docs[9].sentences) == 533
    assert len(docs[10].sentences) == 354
    assert len(docs[11].sentences) == 547

    # Check table numbers
    assert len(docs[0].tables) == 9
    assert len(docs[1].tables) == 9
    assert len(docs[2].tables) == 14
    assert len(docs[3].tables) == 11
    assert len(docs[4].tables) == 11
    assert len(docs[5].tables) == 10
    assert len(docs[6].tables) == 10
    assert len(docs[7].tables) == 2
    assert len(docs[8].tables) == 7
    assert len(docs[9].tables) == 10
    assert len(docs[10].tables) == 6
    assert len(docs[11].tables) == 9

    # Check figure numbers
    assert len(docs[0].figures) == 32
    assert len(docs[1].figures) == 11
    assert len(docs[2].figures) == 38
    assert len(docs[3].figures) == 31
    assert len(docs[4].figures) == 7
    assert len(docs[5].figures) == 38
    assert len(docs[6].figures) == 10
    assert len(docs[7].figures) == 31
    assert len(docs[8].figures) == 4
    assert len(docs[9].figures) == 27
    assert len(docs[10].figures) == 5
    assert len(docs[11].figures) == 27

    # Check caption numbers
    assert len(docs[0].captions) == 0
    assert len(docs[1].captions) == 0
    assert len(docs[2].captions) == 0
    assert len(docs[3].captions) == 0
    assert len(docs[4].captions) == 0
    assert len(docs[5].captions) == 0
    assert len(docs[6].captions) == 0
    assert len(docs[7].captions) == 0
    assert len(docs[8].captions) == 0
    assert len(docs[9].captions) == 0
    assert len(docs[10].captions) == 0
    assert len(docs[11].captions) == 0

    # Divide into train / dev / test by position in the name-sorted list:
    # the first 50% go to train, the next 25% to dev, the remainder to test.
    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (doc_name, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # Matches 150, 155, ..., 195, 200, 205 (candidate temperature values).
    attr_matcher = RegexMatchSpan(rgx=r"(?:[1][5-9]|20)[05]",
                                  longest_match_only=False)

    ### Transistor Naming Conventions as Regular Expressions ###
    eeca_rgx = r"([ABC][A-Z][WXYZ]?[0-9]{3,5}(?:[A-Z]){0,5}[0-9]?[A-Z]?(?:-[A-Z0-9]{1,7})?(?:[-][A-Z0-9]{1,2})?(?:\/DG)?)"
    jedec_rgx = r"(2N\d{3,4}[A-Z]{0,5}[0-9]?[A-Z]?)"
    jis_rgx = r"(2S[ABCDEFGHJKMQRSTVZ]{1}[\d]{2,4})"
    others_rgx = r"((?:NSVBC|SMBT|MJ|MJE|MPS|MRF|RCA|TIP|ZTX|ZT|ZXT|TIS|TIPL|DTC|MMBT|SMMBT|PZT|FZT|STD|BUV|PBSS|KSC|CXT|FCX|CMPT){1}[\d]{2,4}[A-Z]{0,5}(?:-[A-Z0-9]{0,6})?(?:[-][A-Z0-9]{0,1})?)"

    part_rgx = "|".join([eeca_rgx, jedec_rgx, jis_rgx, others_rgx])
    part_rgx_matcher = RegexMatchSpan(rgx=part_rgx, longest_match_only=True)

    def get_digikey_parts_set(path):
        """Read the digikey part dictionary and return the set of parts.

        Every CSV row must have exactly two fields (part, url); the
        unpacking below raises ValueError otherwise. Only the part is kept.
        """
        all_parts = set()
        with open(path, "r") as csvinput:
            reader = csv.reader(csvinput)
            for line in reader:
                (part, url) = line
                all_parts.add(part)
        return all_parts

    ### Dictionary of known transistor parts ###
    dict_path = "tests/e2e/data/digikey_part_dictionary.csv"
    part_dict_matcher = DictionaryMatch(d=get_digikey_parts_set(dict_path))

    def common_prefix_length_diff(str1, str2):
        """Return how much of the shorter string lies past the common prefix.

        The result is 0 when the shorter string is a prefix of the longer
        one; otherwise it is the shorter length minus the index of the
        first differing character.
        """
        for i in range(min(len(str1), len(str2))):
            if str1[i] != str2[i]:
                return min(len(str1), len(str2)) - i
        return 0

    def part_file_name_conditions(attr):
        """Accept a span that closely matches the part in the file name.

        The document name must consist of exactly two '_'-separated pieces,
        the span must not start with '-', and (ignoring dashes) it must
        contain a digit and a letter and differ from the second file-name
        piece by at most 2 characters after their common prefix.
        """
        file_name = attr.sentence.document.name
        if len(file_name.split("_")) != 2:
            return False
        if attr.get_span()[0] == "-":
            return False
        name = attr.get_span().replace("-", "")
        return (any(char.isdigit() for char in name)
                and any(char.isalpha() for char in name) and
                common_prefix_length_diff(file_name.split("_")[1], name) <= 2)

    # Raw string: "\-" is not a valid escape sequence in a normal string
    # literal and triggers a warning on recent Python versions. The pattern
    # value itself is unchanged.
    add_rgx = r"^[A-Z0-9\-]{5,15}$"

    part_file_name_lambda_matcher = LambdaFunctionMatcher(
        func=part_file_name_conditions)
    part_file_name_matcher = Intersect(
        RegexMatchSpan(rgx=add_rgx, longest_match_only=True),
        part_file_name_lambda_matcher,
    )

    part_matcher = Union(part_rgx_matcher, part_dict_matcher,
                         part_file_name_matcher)

    part_ngrams = OmniNgramsPart(parts_by_doc=None, n_max=3)
    attr_ngrams = OmniNgramsTemp(n_max=2)

    def stg_temp_filter(c):
        """Keep same-table candidates only if part and attr are aligned.

        Candidates whose part and attr are not in the same table are always
        kept.
        """
        (part, attr) = c
        if same_table((part, attr)):
            return is_horz_aligned((part, attr)) or is_vert_aligned(
                (part, attr))
        return True

    candidate_filter = stg_temp_filter

    candidate_extractor = CandidateExtractor(
        Part_Attr,
        [part_ngrams, attr_ngrams],
        [part_matcher, attr_matcher],
        candidate_filter=candidate_filter,
    )

    # Split 0 = train, 1 = dev, 2 = test.
    candidate_extractor.apply(train_docs, split=0, parallelism=PARALLEL)

    train_cands = session.query(Part_Attr).filter(Part_Attr.split == 0).all()
    logger.info("Number of candidates: {}".format(len(train_cands)))

    # Loop variable renamed so it no longer shadows the `docs` list above.
    for i, split_docs in enumerate([dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i + 1)
        logger.info("Number of candidates: {}".format(
            session.query(Part_Attr).filter(Part_Attr.split == i + 1).count()))

    # The feature key set is (re)built on the train split only and reused
    # unchanged for dev and test.
    featurizer = BatchFeatureAnnotator(Part_Attr)
    F_train = featurizer.apply(split=0,
                               replace_key_set=True,
                               parallelism=PARALLEL)
    logger.info(F_train.shape)
    F_dev = featurizer.apply(split=1,
                             replace_key_set=False,
                             parallelism=PARALLEL)
    logger.info(F_dev.shape)
    F_test = featurizer.apply(split=2,
                              replace_key_set=False,
                              parallelism=PARALLEL)
    logger.info(F_test.shape)

    gold_file = "tests/e2e/data/hardware_tutorial_gold.csv"
    load_hardware_labels(session,
                         Part_Attr,
                         gold_file,
                         ATTRIBUTE,
                         annotator_name="gold")

    # --- Round 1: positive labeling functions (vote 1 or abstain with 0) ---

    def LF_storage_row(c):
        return 1 if "storage" in get_row_ngrams(c.attr) else 0

    def LF_temperature_row(c):
        return 1 if "temperature" in get_row_ngrams(c.attr) else 0

    def LF_operating_row(c):
        return 1 if "operating" in get_row_ngrams(c.attr) else 0

    def LF_tstg_row(c):
        return 1 if overlap(["tstg", "stg", "ts"], list(get_row_ngrams(
            c.attr))) else 0

    def LF_to_left(c):
        return 1 if "to" in get_left_ngrams(c.attr, window=2) else 0

    def LF_negative_number_left(c):
        # Generator expression: no need to build a throwaway list for any().
        return (1 if any(
            re.match(r"-\s*\d+", ngram)
            for ngram in get_left_ngrams(c.attr, window=4)) else 0)

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    labeler = BatchLabelAnnotator(Part_Attr, lfs=stg_temp_lfs)
    L_train = labeler.apply(split=0, clear=True, parallelism=PARALLEL)
    logger.info(L_train.shape)

    # The gold-label matrices are loaded but not used further in this test.
    L_gold_train = load_gold_labels(session, annotator_name="gold", split=0)

    gen_model = GenerativeModel()
    gen_model.train(L_train,
                    epochs=500,
                    decay=0.9,
                    step_size=0.001 / L_train.shape[0],
                    reg_param=0)
    logger.info("LF Accuracy: {}".format(gen_model.weights.lf_accuracy))

    L_gold_dev = load_gold_labels(session, annotator_name="gold", split=1)

    train_marginals = gen_model.marginals(L_train)

    disc_model = SparseLogisticRegression()
    disc_model.train(F_train, train_marginals, n_epochs=200, lr=0.001)

    L_gold_test = load_gold_labels(session, annotator_name="gold", split=2)

    test_candidates = [
        F_test.get_candidate(session, i) for i in range(F_test.shape[0])
    ]
    test_score = disc_model.predictions(F_test)
    # Candidates whose prediction is positive.
    true_pred = [
        test_candidates[idx] for idx in np.nditer(np.where(test_score > 0))
    ]

    # NOTE(review): pickle.load executes arbitrary code from the file; this
    # is a repository test fixture, never point it at untrusted data.
    pickle_file = "tests/e2e/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    # Precision / recall / F1; NaN when a denominator would be zero (a NaN
    # F1 fails the assertion below because comparisons with NaN are False).
    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info("prec: {}".format(prec))
    logger.info("rec: {}".format(rec))
    logger.info("f1: {}".format(f1))

    assert 0.3 < f1 < 0.7

    # --- Round 2: negative labeling functions (vote -1 or abstain with 0) ---

    def LF_test_condition_aligned(c):
        return (-1 if overlap(["test", "condition"],
                              list(get_aligned_ngrams(c.attr))) else 0)

    def LF_collector_aligned(c):
        return (-1 if overlap(
            [
                "collector",
                "collector-current",
                "collector-base",
                "collector-emitter",
            ],
            list(get_aligned_ngrams(c.attr)),
        ) else 0)

    def LF_current_aligned(c):
        return (-1 if overlap(["current", "dc", "ic"],
                              list(get_aligned_ngrams(c.attr))) else 0)

    def LF_voltage_row_temp(c):
        return (-1 if overlap(["voltage", "cbo", "ceo", "ebo", "v"],
                              list(get_aligned_ngrams(c.attr))) else 0)

    def LF_voltage_row_part(c):
        # NOTE(review): identical to LF_voltage_row_temp (inspects c.attr);
        # presumably c.part was intended — confirm before changing, since
        # the F1 threshold below depends on the current behavior.
        return (-1 if overlap(["voltage", "cbo", "ceo", "ebo", "v"],
                              list(get_aligned_ngrams(c.attr))) else 0)

    def LF_typ_row(c):
        return -1 if overlap(["typ", "typ."], list(get_row_ngrams(
            c.attr))) else 0

    def LF_complement_left_row(c):
        return (-1 if (overlap(
            ["complement", "complementary"],
            chain.from_iterable(
                [get_row_ngrams(c.part),
                 get_left_ngrams(c.part, window=10)]),
        )) else 0)

    def LF_too_many_numbers_row(c):
        num_numbers = list(get_row_ngrams(c.attr,
                                          attrib="ner_tags")).count("number")
        return -1 if num_numbers >= 3 else 0

    def LF_temp_on_high_page_num(c):
        return -1 if c.attr.get_attrib_tokens("page")[0] > 2 else 0

    def LF_temp_outside_table(c):
        # Same truth value as the original `not ... is None`, which parses
        # as `not (... is None)`.
        # NOTE(review): if is_tabular() returns a bool this always votes -1;
        # presumably `not ...is_tabular()` was intended — confirm before
        # changing, since the F1 threshold below depends on this behavior.
        return -1 if c.attr.sentence.is_tabular() is not None else 0

    def LF_not_temp_relevant(c):
        return (-1 if not overlap(
            ["storage", "temperature", "tstg", "stg", "ts"],
            list(get_aligned_ngrams(c.attr)),
        ) else 0)

    stg_temp_lfs_2 = [
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]

    # clear=False with update_keys/update_values: the new labeling functions
    # are applied on top of the labels produced in round 1.
    labeler = BatchLabelAnnotator(Part_Attr, lfs=stg_temp_lfs_2)
    L_train = labeler.apply(split=0,
                            clear=False,
                            update_keys=True,
                            update_values=True,
                            parallelism=PARALLEL)
    gen_model = GenerativeModel()
    gen_model.train(L_train,
                    epochs=500,
                    decay=0.9,
                    step_size=0.001 / L_train.shape[0],
                    reg_param=0)
    train_marginals = gen_model.marginals(L_train)

    disc_model = SparseLogisticRegression()
    disc_model.train(F_train, train_marginals, n_epochs=200, lr=0.001)

    test_candidates = [
        F_test.get_candidate(session, i) for i in range(F_test.shape[0])
    ]
    test_score = disc_model.predictions(F_test)
    true_pred = [
        test_candidates[idx] for idx in np.nditer(np.where(test_score > 0))
    ]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info("prec: {}".format(prec))
    logger.info("rec: {}".format(rec))
    logger.info("f1: {}".format(f1))

    assert f1 > 0.7
예제 #4
0
def test_e2e(caplog):
    """Run an end-to-end test on 12 documents of the hardware domain.

    The test parses ``max_docs`` datasheets, extracts ``Part_Attr``
    candidates, featurizes them and then trains a generative plus a
    discriminative model twice:

    1. with a small set of positive labeling functions (entity-level F1 on
       the test split must lie strictly between 0.4 and 0.7), and
    2. with an additional set of negative labeling functions (F1 must then
       exceed 0.7).

    The function relies on the module-level names ``ATTRIBUTE`` and
    ``logger`` as well as on the Fonduer imports of the enclosing file.

    :param caplog: pytest log-capture fixture; used only to set the capture
        level to INFO.
    """
    caplog.set_level(logging.INFO)
    PARALLEL = 2
    max_docs = 12

    session = SnorkelSession()

    Part_Attr = candidate_subclass('Part_Attr', ['part', 'attr'])

    docs_path = 'tests/e2e/data/html/'
    pdf_path = 'tests/e2e/data/pdf/'

    doc_preprocessor = HTMLPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = OmniParser(
        structural=True, lingual=True, visual=True, pdf_path=pdf_path)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    num_docs = session.query(Document).count()
    logger.info("Docs: {}".format(num_docs))
    assert num_docs == max_docs

    # The phrase count is only logged, not asserted.
    num_phrases = session.query(Phrase).count()
    logger.info("Phrases: {}".format(num_phrases))
    #  assert num_phrases == 20

    # Divide into train / dev / test by position in the name-sorted list:
    # the first 50% go to train, the next 25% to dev, the remainder to test.
    docs = session.query(Document).order_by(Document.name).all()
    ld = len(docs)

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (doc_name, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # Matches 150, 155, ..., 195, 200, 205 (candidate temperature values).
    attr_matcher = RegexMatchSpan(
        rgx=r'(?:[1][5-9]|20)[05]', longest_match_only=False)

    ### Transistor Naming Conventions as Regular Expressions ###
    eeca_rgx = r'([ABC][A-Z][WXYZ]?[0-9]{3,5}(?:[A-Z]){0,5}[0-9]?[A-Z]?(?:-[A-Z0-9]{1,7})?(?:[-][A-Z0-9]{1,2})?(?:\/DG)?)'
    jedec_rgx = r'(2N\d{3,4}[A-Z]{0,5}[0-9]?[A-Z]?)'
    jis_rgx = r'(2S[ABCDEFGHJKMQRSTVZ]{1}[\d]{2,4})'
    others_rgx = r'((?:NSVBC|SMBT|MJ|MJE|MPS|MRF|RCA|TIP|ZTX|ZT|ZXT|TIS|TIPL|DTC|MMBT|SMMBT|PZT|FZT|STD|BUV|PBSS|KSC|CXT|FCX|CMPT){1}[\d]{2,4}[A-Z]{0,5}(?:-[A-Z0-9]{0,6})?(?:[-][A-Z0-9]{0,1})?)'

    part_rgx = '|'.join([eeca_rgx, jedec_rgx, jis_rgx, others_rgx])
    part_rgx_matcher = RegexMatchSpan(rgx=part_rgx, longest_match_only=True)

    def get_digikey_parts_set(path):
        """Read the digikey part dictionary and return the set of parts.

        Every CSV row must have exactly two fields (part, url); the
        unpacking below raises ValueError otherwise. Only the part is kept.
        """
        all_parts = set()
        with open(path, "r") as csvinput:
            reader = csv.reader(csvinput)
            for line in reader:
                (part, url) = line
                all_parts.add(part)
        return all_parts

    ### Dictionary of known transistor parts ###
    dict_path = 'tests/e2e/data/digikey_part_dictionary.csv'
    part_dict_matcher = DictionaryMatch(d=get_digikey_parts_set(dict_path))

    def common_prefix_length_diff(str1, str2):
        """Return how much of the shorter string lies past the common prefix.

        The result is 0 when the shorter string is a prefix of the longer
        one; otherwise it is the shorter length minus the index of the
        first differing character.
        """
        for i in range(min(len(str1), len(str2))):
            if str1[i] != str2[i]:
                return min(len(str1), len(str2)) - i
        return 0

    def part_file_name_conditions(attr):
        """Accept a span that closely matches the part in the file name.

        The document name must consist of exactly two '_'-separated pieces,
        the span must not start with '-', and (ignoring dashes) it must
        contain a digit and a letter and differ from the second file-name
        piece by at most 2 characters after their common prefix.
        """
        file_name = attr.sentence.document.name
        if len(file_name.split('_')) != 2:
            return False
        if attr.get_span()[0] == '-':
            return False
        name = attr.get_span().replace('-', '')
        return any(char.isdigit() for char in name) and any(
            char.isalpha() for char in name) and common_prefix_length_diff(
                file_name.split('_')[1], name) <= 2

    # Raw string: '\-' is not a valid escape sequence in a normal string
    # literal and triggers a warning on recent Python versions. The pattern
    # value itself is unchanged.
    add_rgx = r'^[A-Z0-9\-]{5,15}$'

    part_file_name_lambda_matcher = LambdaFunctionMatcher(
        func=part_file_name_conditions)
    part_file_name_matcher = Intersect(
        RegexMatchSpan(rgx=add_rgx, longest_match_only=True),
        part_file_name_lambda_matcher)

    part_matcher = Union(part_rgx_matcher, part_dict_matcher,
                         part_file_name_matcher)

    part_ngrams = OmniNgramsPart(parts_by_doc=None, n_max=3)
    attr_ngrams = OmniNgramsTemp(n_max=2)

    def stg_temp_filter(c):
        """Keep same-table candidates only if part and attr are aligned.

        Candidates whose part and attr are not in the same table are always
        kept.
        """
        (part, attr) = c
        if same_table((part, attr)):
            return (is_horz_aligned((part, attr)) or is_vert_aligned(
                (part, attr)))
        return True

    candidate_filter = stg_temp_filter

    candidate_extractor = CandidateExtractor(
        Part_Attr, [part_ngrams, attr_ngrams], [part_matcher, attr_matcher],
        candidate_filter=candidate_filter)

    # Split 0 = train, 1 = dev, 2 = test.
    candidate_extractor.apply(train_docs, split=0, parallelism=PARALLEL)

    train_cands = session.query(Part_Attr).filter(Part_Attr.split == 0).all()
    logger.info("Number of candidates: {}".format(len(train_cands)))

    # Loop variable renamed so it no longer shadows the `docs` list above.
    for i, split_docs in enumerate([dev_docs, test_docs]):
        candidate_extractor.apply(split_docs, split=i + 1)
        logger.info("Number of candidates: {}".format(
            session.query(Part_Attr).filter(Part_Attr.split == i + 1).count()))

    # The feature key set is (re)built on the train split only and reused
    # unchanged for dev and test.
    featurizer = BatchFeatureAnnotator(Part_Attr)
    F_train = featurizer.apply(
        split=0, replace_key_set=True, parallelism=PARALLEL)
    logger.info(F_train.shape)
    F_dev = featurizer.apply(
        split=1, replace_key_set=False, parallelism=PARALLEL)
    logger.info(F_dev.shape)
    F_test = featurizer.apply(
        split=2, replace_key_set=False, parallelism=PARALLEL)
    logger.info(F_test.shape)

    gold_file = 'tests/e2e/data/hardware_tutorial_gold.csv'
    load_hardware_labels(
        session, Part_Attr, gold_file, ATTRIBUTE, annotator_name='gold')

    # --- Round 1: positive labeling functions (vote 1 or abstain with 0) ---

    def LF_storage_row(c):
        return 1 if 'storage' in get_row_ngrams(c.attr) else 0

    def LF_temperature_row(c):
        return 1 if 'temperature' in get_row_ngrams(c.attr) else 0

    def LF_operating_row(c):
        return 1 if 'operating' in get_row_ngrams(c.attr) else 0

    def LF_tstg_row(c):
        return 1 if overlap(['tstg', 'stg', 'ts'], list(
            get_row_ngrams(c.attr))) else 0

    def LF_to_left(c):
        return 1 if 'to' in get_left_ngrams(c.attr, window=2) else 0

    def LF_negative_number_left(c):
        # Generator expression: no need to build a throwaway list for any().
        return 1 if any(
            re.match(r'-\s*\d+', ngram)
            for ngram in get_left_ngrams(c.attr, window=4)) else 0

    stg_temp_lfs = [
        LF_storage_row, LF_operating_row, LF_temperature_row, LF_tstg_row,
        LF_to_left, LF_negative_number_left
    ]

    labeler = BatchLabelAnnotator(Part_Attr, lfs=stg_temp_lfs)
    L_train = labeler.apply(split=0, clear=True, parallelism=PARALLEL)
    logger.info(L_train.shape)

    # The gold-label matrices are loaded but not used further in this test.
    L_gold_train = load_gold_labels(session, annotator_name='gold', split=0)

    gen_model = GenerativeModel()
    gen_model.train(
        L_train,
        epochs=500,
        decay=0.9,
        step_size=0.001 / L_train.shape[0],
        reg_param=0)
    logger.info("LF Accuracy: {}".format(gen_model.weights.lf_accuracy))

    L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)

    train_marginals = gen_model.marginals(L_train)

    disc_model = SparseLogisticRegression()
    disc_model.train(F_train, train_marginals, n_epochs=200, lr=0.001)

    L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)

    test_candidates = [
        F_test.get_candidate(session, i) for i in range(F_test.shape[0])
    ]
    test_score = disc_model.predictions(F_test)
    # Candidates whose prediction is positive.
    true_pred = [
        test_candidates[idx] for idx in np.nditer(np.where(test_score > 0))
    ]

    # NOTE(review): pickle.load executes arbitrary code from the file; this
    # is a repository test fixture, never point it at untrusted data.
    pickle_file = 'tests/e2e/data/parts_by_doc_dict.pkl'
    with open(pickle_file, 'rb') as f:
        parts_by_doc = pickle.load(f)

    (TP, FP, FN) = entity_level_f1(
        true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc)

    # Precision / recall / F1; NaN when a denominator would be zero (a NaN
    # F1 fails the assertion below because comparisons with NaN are False).
    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float('nan')
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float('nan')
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float('nan')

    logger.info("prec: {}".format(prec))
    logger.info("rec: {}".format(rec))
    logger.info("f1: {}".format(f1))

    assert 0.4 < f1 < 0.7

    # --- Round 2: negative labeling functions (vote -1 or abstain with 0) ---

    def LF_test_condition_aligned(c):
        return -1 if overlap(['test', 'condition'],
                             list(get_aligned_ngrams(c.attr))) else 0

    def LF_collector_aligned(c):
        return -1 if overlap([
            'collector', 'collector-current', 'collector-base',
            'collector-emitter'
        ], list(get_aligned_ngrams(c.attr))) else 0

    def LF_current_aligned(c):
        return -1 if overlap(['current', 'dc', 'ic'],
                             list(get_aligned_ngrams(c.attr))) else 0

    def LF_voltage_row_temp(c):
        return -1 if overlap(['voltage', 'cbo', 'ceo', 'ebo', 'v'],
                             list(get_aligned_ngrams(c.attr))) else 0

    def LF_voltage_row_part(c):
        # NOTE(review): identical to LF_voltage_row_temp (inspects c.attr);
        # presumably c.part was intended — confirm before changing, since
        # the F1 threshold below depends on the current behavior.
        return -1 if overlap(['voltage', 'cbo', 'ceo', 'ebo', 'v'],
                             list(get_aligned_ngrams(c.attr))) else 0

    def LF_typ_row(c):
        return -1 if overlap(['typ', 'typ.'],
                             list(get_row_ngrams(c.attr))) else 0

    def LF_complement_left_row(c):
        return -1 if (overlap(['complement', 'complementary'],
                              chain.from_iterable([
                                  get_row_ngrams(c.part),
                                  get_left_ngrams(c.part, window=10)
                              ]))) else 0

    def LF_too_many_numbers_row(c):
        num_numbers = list(get_row_ngrams(c.attr,
                                          attrib="ner_tags")).count('number')
        return -1 if num_numbers >= 3 else 0

    def LF_temp_on_high_page_num(c):
        return -1 if c.attr.get_attrib_tokens('page')[0] > 2 else 0

    def LF_temp_outside_table(c):
        # Same truth value as the original `not ... is None`, which parses
        # as `not (... is None)`.
        # NOTE(review): if is_tabular() returns a bool this always votes -1;
        # presumably `not ...is_tabular()` was intended — confirm before
        # changing, since the F1 threshold below depends on this behavior.
        return -1 if c.attr.sentence.is_tabular() is not None else 0

    def LF_not_temp_relevant(c):
        return -1 if not overlap(
            ['storage', 'temperature', 'tstg', 'stg', 'ts'],
            list(get_aligned_ngrams(c.attr))) else 0

    stg_temp_lfs_2 = [
        LF_test_condition_aligned, LF_collector_aligned, LF_current_aligned,
        LF_voltage_row_temp, LF_voltage_row_part, LF_typ_row,
        LF_complement_left_row, LF_too_many_numbers_row,
        LF_temp_on_high_page_num, LF_temp_outside_table, LF_not_temp_relevant
    ]

    # clear=False with update_keys/update_values: the new labeling functions
    # are applied on top of the labels produced in round 1.
    labeler = BatchLabelAnnotator(Part_Attr, lfs=stg_temp_lfs_2)
    L_train = labeler.apply(
        split=0,
        clear=False,
        update_keys=True,
        update_values=True,
        parallelism=PARALLEL)
    gen_model = GenerativeModel()
    gen_model.train(
        L_train,
        epochs=500,
        decay=0.9,
        step_size=0.001 / L_train.shape[0],
        reg_param=0)
    train_marginals = gen_model.marginals(L_train)

    disc_model = SparseLogisticRegression()
    disc_model.train(F_train, train_marginals, n_epochs=200, lr=0.001)

    test_candidates = [
        F_test.get_candidate(session, i) for i in range(F_test.shape[0])
    ]
    test_score = disc_model.predictions(F_test)
    true_pred = [
        test_candidates[idx] for idx in np.nditer(np.where(test_score > 0))
    ]

    (TP, FP, FN) = entity_level_f1(
        true_pred, gold_file, ATTRIBUTE, test_docs, parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float('nan')
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float('nan')
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float('nan')

    logger.info("prec: {}".format(prec))
    logger.info("rec: {}".format(rec))
    logger.info("f1: {}".format(f1))

    assert f1 > 0.7
예제 #5
0
#
# We first initialize a `SnorkelSession`, which manages the connection to the database automatically and enables us to save intermediate results.

# In[2]:

from fonduer import SnorkelSession

session = SnorkelSession()

# Next, we define the _schema_ of the relation we want to extract. This must be a subclass of Candidate, and we define it using a helper function. Here, we define a binary relation which connects two Span objects of text. This is what creates the relation's database table if it does not already exist.

# In[3]:

from fonduer import candidate_subclass

Part_Attr = candidate_subclass('Part_Attr', ['part', 'attr'])

# ## 1.2 Parsing and Transforming the Input Documents into Unified Data Models
#
# Next, we load the corpus of datasheets and transform them into the unified data model. Each datasheet has a PDF and an HTML representation. Both representations are used in conjunction to create a robust unified data model with textual, structural, tabular, and visual modality information. Since the documents are independent of each other, we can parse them in parallel. Note that parallel execution will not work with SQLite, the default database engine; we depend on PostgreSQL for this functionality.
#
# ### Configuring an `HTMLPreprocessor`
# We start by setting the paths to where our documents are stored, and defining an `HTMLPreprocessor` to read in the documents found in the specified paths. `max_docs` specifies the number of documents to parse. For the sake of this tutorial, we only look at 100 documents.
#
# **Note that you need to have run `download_data.sh` before executing these next steps or you won't have the documents needed for the tutorial.**

# In[4]:

from fonduer import HTMLPreprocessor, OmniParser

# Requires the FONDUERHOME environment variable to be set (KeyError
# otherwise); the value is expected to have no trailing slash, since the
# suffix starts with '/'.
docs_path = os.environ['FONDUERHOME'] + '/tutorials/hardware/data/html/'