コード例 #1
0
ファイル: test_candidates.py プロジェクト: shazhiju/fonduer
def test_multimodal_cand(caplog):
    """Check mention extraction for every multimodal context type.

    Runs the parser over the radiology HTML file, then a ``MentionExtractor``
    configured with eight mention classes (one per mention space) that all
    share a single ``DoNothingMatcher``, and asserts the number of rows
    stored for each mention class.
    """
    caplog.set_level(logging.INFO)

    PARALLEL = 4
    max_docs = 1
    docs_path = "tests/data/pure_html/radiology.html"

    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)

    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 35

    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction: create the mention classes first ...
    ms_doc, ms_sec, ms_tab, ms_fig, ms_cell, ms_para, ms_cap, ms_sent = (
        mention_subclass(label)
        for label in (
            "m_doc",
            "m_sec",
            "m_tab",
            "m_fig",
            "m_cell",
            "m_para",
            "m_cap",
            "m_sent",
        )
    )

    # ... then one mention space per class.
    m_doc, m_sec, m_tab, m_fig = (
        MentionDocuments(),
        MentionSections(),
        MentionTables(),
        MentionFigures(),
    )
    m_cell, m_para, m_cap, m_sent = (
        MentionCells(),
        MentionParagraphs(),
        MentionCaptions(),
        MentionSentences(),
    )

    # The three lists are positionally aligned: class, space, matcher.
    ms = [ms_doc, ms_cap, ms_sec, ms_tab, ms_fig, ms_para, ms_sent, ms_cell]
    m = [m_doc, m_cap, m_sec, m_tab, m_fig, m_para, m_sent, m_cell]
    matchers = [DoNothingMatcher()] * len(ms)

    mention_extractor = MentionExtractor(
        session, ms, m, matchers, parallelism=PARALLEL
    )
    mention_extractor.apply(docs)

    # Expected row counts, listed in the same order as ``ms``.
    expected_counts = (1, 2, 5, 2, 2, 30, 35, 21)
    for mention_cls, expected in zip(ms, expected_counts):
        assert session.query(mention_cls).count() == expected
コード例 #2
0
ファイル: test_candidates.py プロジェクト: gadgetlabs/fonduer
def test_multimodal_cand():
    """Check mention extraction for every multimodal context type.

    Parses the radiology HTML file with ``parse_doc``, applies a
    ``MentionExtractorUDF`` configured with eight mention classes that all
    share a single ``DoNothingMatcher``, and asserts how many mentions of
    each class end up attached to the document.
    """
    file_name = "radiology"
    doc = parse_doc(f"tests/data/pure_html/{file_name}.html", file_name)

    assert len(doc.sentences) == 35

    # Mention Extraction: create the mention classes first ...
    ms_doc, ms_sec, ms_tab, ms_fig, ms_cell, ms_para, ms_cap, ms_sent = (
        mention_subclass(label)
        for label in (
            "m_doc",
            "m_sec",
            "m_tab",
            "m_fig",
            "m_cell",
            "m_para",
            "m_cap",
            "m_sent",
        )
    )

    # ... then one mention space per class.
    m_doc, m_sec, m_tab, m_fig = (
        MentionDocuments(),
        MentionSections(),
        MentionTables(),
        MentionFigures(),
    )
    m_cell, m_para, m_cap, m_sent = (
        MentionCells(),
        MentionParagraphs(),
        MentionCaptions(),
        MentionSentences(),
    )

    # The three lists are positionally aligned: class, space, matcher.
    ms = [ms_doc, ms_cap, ms_sec, ms_tab, ms_fig, ms_para, ms_sent, ms_cell]
    m = [m_doc, m_cap, m_sec, m_tab, m_fig, m_para, m_sent, m_cell]
    matchers = [DoNothingMatcher()] * len(ms)

    mention_extractor_udf = MentionExtractorUDF(ms, m, matchers)
    doc = mention_extractor_udf.apply(doc)

    # Document attribute holding each mention class -> expected size.
    expected_sizes = (
        ("m_docs", 1),
        ("m_caps", 2),
        ("m_secs", 5),
        ("m_tabs", 2),
        ("m_figs", 2),
        ("m_paras", 30),
        ("m_sents", 35),
        ("m_cells", 21),
    )
    for attr, expected in expected_sizes:
        assert len(getattr(doc, attr)) == expected
コード例 #3
0
    def apply(self, doc: Document) -> Iterator[TemporaryContext]:
        """Yield temporary contexts for a Document from three mention spaces.

        After validating the input, the results of ``MentionNgrams.apply``,
        ``MentionDocuments.apply`` and ``MentionFigures.apply`` are yielded
        one after another, in that order.

        :param doc: The ``Document`` to parse.
        :type doc: ``Document``
        :raises TypeError: If the input doc is not of type ``Document``.
        """
        if not isinstance(doc, Document):
            raise TypeError(
                "Input Contexts to MentionNgrams.apply() must be of type Document"
            )

        # Each delegate is called unbound with this instance as ``self``.
        for mention_space in (MentionNgrams, MentionDocuments, MentionFigures):
            yield from mention_space.apply(self, doc)
コード例 #4
0
ファイル: test_candidates.py プロジェクト: sbrown-ai/fonduer
def test_cand_gen(caplog):
    """Test extracting candidates from mentions from documents.

    Steps, in order:

    1. Parse one document (with visual information from the PDF).
    2. Check that ``MentionExtractor`` rejects argument lists of mismatched
       length, then extract Part/Temp/Volt/Fig mentions.
    3. Check that ``CandidateExtractor`` rejects mismatched throttler lists,
       then extract candidates with no throttlers, with a ``None`` throttler
       in the list, and with one throttler per candidate class.
    4. Check that deleting candidates leaves mentions in place, and that
       deleting mentions removes the candidates built from them.
    """
    caplog.set_level(logging.INFO)

    if platform == "darwin":
        logger.info("Using single core.")
        PARALLEL = 1
    else:
        logger.info("Using two cores.")
        PARALLEL = 2  # Travis only gives 2 cores

    def do_nothing_matcher(fig):
        # Accept every figure.
        return True

    max_docs = 1
    session = Meta.init("postgresql://localhost:5432/" + DB).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    assert session.query(Document).count() == max_docs
    assert session.query(Sentence).count() == 799
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    with pytest.raises(ValueError):
        MentionExtractor(
            session,
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    with pytest.raises(ValueError):
        MentionExtractor(
            session,
            [Part, Temp, Volt],
            # Well-formed mention spaces, so that only the short matcher list
            # below can be responsible for the ValueError.
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 70
    assert session.query(Volt).count() == 33
    assert session.query(Temp).count() == 23
    assert session.query(Fig).count() == 31
    part = session.query(Part).order_by(Part.id).all()[0]
    volt = session.query(Volt).order_by(Volt.id).all()[0]
    temp = session.query(Temp).order_by(Temp.id).all()[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Volt: {volt.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    with pytest.raises(ValueError):
        CandidateExtractor(
            session,
            [PartTemp, PartVolt],
            throttlers=[
                temp_throttler,
                volt_throttler,
                volt_throttler,
            ],  # Fail, mismatched arity
        )

    with pytest.raises(ValueError):
        CandidateExtractor(
            session,
            [PartTemp],  # Fail, mismatched arity
            throttlers=[temp_throttler, volt_throttler],
        )

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt]
    )  # Pass, no throttler

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1610
    assert session.query(PartVolt).count() == 2310
    assert session.query(Candidate).count() == 3920
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0
    assert session.query(PartTemp).count() == 0
    assert session.query(PartVolt).count() == 0

    # Test with None in throttlers in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, None]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    assert session.query(PartTemp).count() == 1432
    assert session.query(PartVolt).count() == 2310
    assert session.query(Candidate).count() == 3742
    candidate_extractor.clear_all(split=0)
    assert session.query(Candidate).count() == 0

    # Test with one throttler per candidate class
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, volt_throttler]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)

    assert session.query(PartTemp).count() == 1432
    assert session.query(PartVolt).count() == 1993
    assert session.query(Candidate).count() == 3425
    assert docs[0].name == "112823"
    assert len(docs[0].parts) == 70
    assert len(docs[0].volts) == 33
    assert len(docs[0].temps) == 23

    # Test that deletion of a Candidate does not delete the Mention
    session.query(PartTemp).delete(synchronize_session="fetch")
    assert session.query(PartTemp).count() == 0
    assert session.query(Temp).count() == 23
    assert session.query(Part).count() == 70

    # Test deletion of Candidate if Mention is deleted
    assert session.query(PartVolt).count() == 1993
    assert session.query(Volt).count() == 33
    session.query(Volt).delete(synchronize_session="fetch")
    assert session.query(Volt).count() == 0
    assert session.query(PartVolt).count() == 0
コード例 #5
0
def test_too_many_clients_error_should_not_happen():
    """Too many clients error should not happen.

    Runs parsing, mention extraction and candidate extraction with a
    parallelism of 32. The test makes no assertions: it passes as long as
    none of the steps raises.
    """
    PARALLEL = 32
    # f-string prefix is required here; without it the placeholder text is
    # logged literally instead of the value.
    logger.info(f"Parallel: {PARALLEL}")

    def do_nothing_matcher(fig):
        # Accept every figure.
        return True

    max_docs = 1
    session = Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    # Parsing
    logger.info("Parsing...")
    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)
    corpus_parser = Parser(
        session, structural=True, lingual=True, visual=True, pdf_path=pdf_path
    )
    corpus_parser.apply(doc_preprocessor, parallelism=PARALLEL)
    docs = session.query(Document).order_by(Document.name).all()

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    mention_extractor.apply(docs, parallelism=PARALLEL)

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    # Test that no throttler in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt]
    )  # Pass, no throttler

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
    candidate_extractor.clear_all(split=0)

    # Test with None in throttlers in candidate extractor
    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt], throttlers=[temp_throttler, None]
    )

    candidate_extractor.apply(docs, split=0, parallelism=PARALLEL)
コード例 #6
0
ファイル: circular_connectors.py プロジェクト: q5s2c1/TDJEE
def main(
    conn_string,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    gpu=None,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Run the thumbnail-classification pipeline end to end.

    Parses the dataset, extracts figure mentions and single-mention
    candidates, labels dev/test candidates against ``data/ground_truth.txt``,
    runs a random hyper-parameter search over an ``EndModel`` with a
    ``resnet18`` input module, and logs precision/recall/f1 on the test split.

    :param conn_string: Database connection string passed to ``Meta.init``.
    :param max_docs: Forwarded to ``parse_dataset``.
    :param parse: Not used inside this function.
    :param first_time: Forwarded to ``parse_dataset``; also gates whether
        mention and candidate extraction are (re)run.
    :param gpu: If truthy, the model device is set to ``"cuda"`` and
        ``torch.cuda.set_device(int(gpu))`` is called.
    :param parallel: Parallelism for parsing and both extractors.
    :param log_dir: Log directory relative to this file; defaults to
        ``"logs"``.
    :param verbose: Log at INFO level when true, WARNING otherwise.
    """
    if not log_dir:
        log_dir = "logs"

    level = logging.INFO if verbose else logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    tuner_config = {"max_search": 3}

    em_config = {
        # GENERAL
        "seed": None,
        "verbose": True,
        "show_plots": True,
        # Network
        # The first value is the output dim of the input module (or the sum of
        # the output dims of all the input modules if multitask=True and
        # multiple input modules are provided). The last value is the
        # output dim of the head layer (i.e., the cardinality of the
        # classification task). The remaining values are the output dims of
        # middle layers (if any). The number of middle layers will be inferred
        # from this list.
        #     "layer_out_dims": [10, 2],
        # Input layer configs
        "input_layer_config": {
            "input_relu": False,
            "input_batchnorm": False,
            "input_dropout": 0.0,
        },
        # Middle layer configs
        "middle_layer_config": {
            "middle_relu": False,
            "middle_batchnorm": False,
            "middle_dropout": 0.0,
        },
        # Can optionally skip the head layer completely, for e.g. running baseline
        # models...
        "skip_head": True,
        # GPU
        "device": "cpu",
        # DATA CONFIG
        # NOTE: a stray "resnet18" string literal used to sit directly above
        # this entry; Python concatenated it with the key, which silently
        # turned "src" into "resnet18src". The model name itself is passed to
        # get_cnn() further down.
        "src": "gm",
        # TRAINING
        "train_config": {
            # Display
            "print_every": 1,  # Print after this many epochs
            "disable_prog_bar": False,  # Disable progress bar each epoch
            # Dataloader
            "data_loader_config": {
                "batch_size": 32,
                "num_workers": 8,
                "sampler": None,
            },
            # Loss weights
            "loss_weights": [0.5, 0.5],
            # Train Loop
            "n_epochs": 20,
            # 'grad_clip': 0.0,
            "l2": 0.0,
            # "lr": 0.01,
            "validation_metric": "accuracy",
            "validation_freq": 1,
            # Evaluate dev for during training every this many epochs
            # Optimizer
            "optimizer_config": {
                "optimizer": "adam",
                "optimizer_common": {"lr": 0.01},
                # Optimizer - SGD
                "sgd_config": {"momentum": 0.9},
                # Optimizer - Adam
                "adam_config": {"betas": (0.9, 0.999)},
            },
            # Scheduler
            "scheduler_config": {
                "scheduler": "reduce_on_plateau",
                # ['constant', 'exponential', 'reduce_on_plateu']
                # Freeze learning rate initially this many epochs
                "lr_freeze": 0,
                # Scheduler - exponential
                "exponential_config": {"gamma": 0.9},  # decay rate
                # Scheduler - reduce_on_plateau
                "plateau_config": {
                    "factor": 0.5,
                    "patience": 1,
                    "threshold": 0.0001,
                    "min_lr": 1e-5,
                },
            },
            # Checkpointer
            "checkpoint": True,
            "checkpoint_config": {
                "checkpoint_min": -1,
                # The initial best score to beat to merit checkpointing
                "checkpoint_runway": 0,
                # Don't start taking checkpoints until after this many epochs
            },
        },
    }

    session = Meta.init(conn_string).Session()

    # All relative "data/..." paths below are resolved against this file's
    # directory, so make it the working directory.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    logger.info(f"CWD: {os.getcwd()}")
    dirname = "."

    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session,
        dirname,
        first_time=first_time,
        parallel=parallel,
        max_docs=max_docs,
    )
    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")

    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    Thumbnails = mention_subclass("Thumbnails")

    thumbnails_img = MentionFigures()

    class HasFigures(_Matcher):
        """Match figures whose image file exists and exceeds 50px per side."""

        def _f(self, m):
            file_path = ""
            # If the file exists under several prefixes, the last one wins.
            for prefix in [
                "data/train/html/",
                "data/dev/html/",
                "data/test/html/",
            ]:
                if os.path.exists(prefix + m.figure.url):
                    file_path = prefix + m.figure.url
            if file_path == "":
                return False
            img = Image.open(file_path)
            width, height = img.size
            min_value = min(width, height)
            return min_value > 50

    mention_extractor = MentionExtractor(
        session,
        [Thumbnails],
        [thumbnails_img],
        [HasFigures()],
        parallelism=parallel,
    )

    if first_time:
        mention_extractor.apply(docs)

    logger.info("Total Mentions: {}".format(session.query(Mention).count()))

    ThumbnailLabel = candidate_subclass("ThumbnailLabel", [Thumbnails])

    candidate_extractor = CandidateExtractor(
        session, [ThumbnailLabel], throttlers=[None], parallelism=parallel
    )

    if first_time:
        candidate_extractor.apply(train_docs, split=0)
        candidate_extractor.apply(dev_docs, split=1)
        candidate_extractor.apply(test_docs, split=2)

    train_cands = candidate_extractor.get_candidates(split=0)
    dev_cands = candidate_extractor.get_candidates(split=1)
    test_cands = candidate_extractor.get_candidates(split=2)

    logger.info("Total train candidate:\t{}".format(len(train_cands[0])))
    logger.info("Total dev candidate:\t{}".format(len(dev_cands[0])))
    logger.info("Total test candidate:\t{}".format(len(test_cands[0])))

    # Each ground-truth line is lower-cased and its whitespace-separated
    # fields are joined with "::" to form one lookup key. The context manager
    # closes the file even if reading fails.
    gt = set()
    with open("data/ground_truth.txt", "r") as fin:
        for line in fin:
            gt.add("::".join(line.lower().split()))

    def LF_gt_label(c):
        # Key format mirrors the ground-truth entries: "<doc>.pdf::<image>".
        doc_file_id = (
            f"{c[0].context.figure.document.name.lower()}.pdf::"
            f"{os.path.basename(c[0].context.figure.url.lower())}"
        )
        return TRUE if doc_file_id in gt else FALSE

    # Dev labels are kept in two forms: two-element probability rows (used
    # as training labels) and class ids 1.0 / 2.0 (used for validation).
    gt_dev_pb = []
    gt_dev = []
    for cand in dev_cands[0]:
        if LF_gt_label(cand) == 1:
            gt_dev_pb.append([1.0, 0.0])
            gt_dev.append(1.0)
        else:
            gt_dev_pb.append([0.0, 1.0])
            gt_dev.append(2.0)

    gt_test = [LF_gt_label(cand) for cand in test_cands[0]]

    batch_size = 64
    input_size = 224

    # NOTE: the training loader is fed the dev split together with its
    # ground-truth probabilistic labels.
    train_loader = torch.utils.data.DataLoader(
        ImageList(
            data=dev_cands[0],
            label=torch.Tensor(gt_dev_pb),
            transform=transform(input_size),
            prefix="data/dev/html/",
        ),
        batch_size=batch_size,
        shuffle=False,
    )

    dev_loader = torch.utils.data.DataLoader(
        ImageList(
            data=dev_cands[0],
            label=gt_dev,
            transform=transform(input_size),
            prefix="data/dev/html/",
        ),
        batch_size=batch_size,
        shuffle=False,
    )

    test_loader = torch.utils.data.DataLoader(
        ImageList(
            data=test_cands[0],
            label=gt_test,
            transform=transform(input_size),
            prefix="data/test/html/",
        ),
        batch_size=100,
        shuffle=False,
    )

    search_space = {
        "l2": [0.001, 0.0001, 0.00001],  # linear range
        "lr": {"range": [0.0001, 0.1], "scale": "log"},  # log range
    }

    train_config = em_config["train_config"]

    # Defining network parameters
    num_classes = 2
    #  fc_size = 2
    #  hidden_size = 2
    pretrained = True

    # Set CUDA device
    if gpu:
        em_config["device"] = "cuda"
        torch.cuda.set_device(int(gpu))

    # Initializing input module
    input_module = get_cnn(
        "resnet18", pretrained=pretrained, num_classes=num_classes
    )

    # Initializing model object
    init_args = [[num_classes]]
    init_kwargs = {"input_module": input_module}
    init_kwargs.update(em_config)

    # Searching model
    log_config = {
        "log_dir": os.path.join(dirname, log_dir),
        "run_name": "image",
    }
    searcher = RandomSearchTuner(EndModel, **log_config)

    end_model = searcher.search(
        search_space,
        dev_loader,
        train_args=[train_loader],
        init_args=init_args,
        init_kwargs=init_kwargs,
        train_kwargs=train_config,
        max_search=tuner_config["max_search"],
    )

    # Evaluating model
    scores = end_model.score(
        test_loader,
        metric=["accuracy", "precision", "recall", "f1"],
        break_ties="abstain",
    )
    # ``scores`` follows the order of ``metric``; index 0 (accuracy) is not
    # reported.
    logger.warning("End Model Score:")
    logger.warning(f"precision: {scores[1]:.3f}")
    logger.warning(f"recall: {scores[2]:.3f}")
    logger.warning(f"f1: {scores[3]:.3f}")
コード例 #7
0
def test_cand_gen():
    """Test extracting candidates from mentions from documents.

    Parses document ``112823`` with ``parse_doc``, checks the arity
    validation of ``MentionExtractor`` and ``CandidateExtractor``, then runs
    the mention and candidate UDFs directly on the document with no
    throttlers, with a ``None`` throttler in the list, and with one throttler
    per candidate class.
    """

    def do_nothing_matcher(fig):
        # Accept every figure.
        return True

    docs_path = "tests/data/html/112823.html"
    pdf_path = "tests/data/pdf/112823.pdf"
    doc = parse_doc(docs_path, "112823", pdf_path)

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)
    figs = MentionFigures(types="png")

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")
    Fig = mention_subclass("Fig")

    fig_matcher = LambdaFunctionFigureMatcher(func=do_nothing_matcher)

    # The extractors validate their arguments in the constructor, so a
    # placeholder is passed instead of a session in the failure cases.
    with pytest.raises(ValueError):
        MentionExtractor(
            "dummy",
            [Part, Temp, Volt],
            [part_ngrams, volt_ngrams],  # Fail, mismatched arity
            [part_matcher, temp_matcher, volt_matcher],
        )
    with pytest.raises(ValueError):
        MentionExtractor(
            "dummy",
            [Part, Temp, Volt],
            # Well-formed mention spaces, so that only the short matcher list
            # below can be responsible for the ValueError.
            [part_ngrams, temp_ngrams, volt_ngrams],
            [part_matcher, temp_matcher],  # Fail, mismatched arity
        )

    mention_extractor_udf = MentionExtractorUDF(
        [Part, Temp, Volt, Fig],
        [part_ngrams, temp_ngrams, volt_ngrams, figs],
        [part_matcher, temp_matcher, volt_matcher, fig_matcher],
    )
    doc = mention_extractor_udf.apply(doc)

    assert len(doc.parts) == 70
    assert len(doc.volts) == 33
    assert len(doc.temps) == 23
    assert len(doc.figs) == 31
    part = doc.parts[0]
    volt = doc.volts[0]
    temp = doc.temps[0]
    logger.info(f"Part: {part.context}")
    logger.info(f"Volt: {volt.context}")
    logger.info(f"Temp: {temp.context}")

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    with pytest.raises(ValueError):
        CandidateExtractor(
            "dummy",
            [PartTemp, PartVolt],
            throttlers=[
                temp_throttler,
                volt_throttler,
                volt_throttler,
            ],  # Fail, mismatched arity
        )

    with pytest.raises(ValueError):
        CandidateExtractor(
            "dummy",
            [PartTemp],  # Fail, mismatched arity
            throttlers=[temp_throttler, volt_throttler],
        )

    # Test that no throttler in candidate extractor
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [None, None], False, False, True  # Pass, no throttler
    )

    doc = candidate_extractor_udf.apply(doc, split=0)

    assert len(doc.part_temps) == 1610
    assert len(doc.part_volts) == 2310

    # Clear
    doc.part_temps = []
    doc.part_volts = []

    # Test with None in throttlers in candidate extractor
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [temp_throttler, None], False, False, True
    )

    doc = candidate_extractor_udf.apply(doc, split=0)
    assert len(doc.part_temps) == 1432
    assert len(doc.part_volts) == 2310

    # Clear
    doc.part_temps = []
    doc.part_volts = []

    # Test with one throttler per candidate class
    candidate_extractor_udf = CandidateExtractorUDF(
        [PartTemp, PartVolt], [temp_throttler, volt_throttler], False, False, True
    )

    doc = candidate_extractor_udf.apply(doc, split=0)

    assert len(doc.part_temps) == 1432
    assert len(doc.part_volts) == 1993
    assert len(doc.parts) == 70
    assert len(doc.volts) == 33
    assert len(doc.temps) == 23
コード例 #8
0
def main(
    conn_string,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    gpu=None,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    """Run the thumbnail-classification pipeline end to end with Emmental.

    Parses the dataset, extracts figure mentions and single-mention
    candidates, labels dev/test candidates against ``data/ground_truth.txt``,
    trains a ``resnet18`` task on the dev split and logs precision/recall/f1
    for the test split. Parse, candidate-extraction and supervision times are
    logged at WARNING level.

    :param conn_string: Database connection string passed to ``Meta.init``.
    :param max_docs: Forwarded to ``parse_dataset``.
    :param parse: Not used inside this function.
    :param first_time: Forwarded to ``parse_dataset``; also gates whether
        mention and candidate extraction are (re)run.
    :param gpu: Not used inside this function.
    :param parallel: Parallelism for parsing and both extractors.
    :param log_dir: Log directory relative to this file; defaults to
        ``"logs"``.
    :param verbose: Log at INFO level when true, WARNING otherwise.
    """
    if not log_dir:
        log_dir = "logs"

    level = logging.INFO if verbose else logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info("Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=first_time, parallel=parallel, max_docs=max_docs
    )
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")

    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention and candidate extraction
    start = timer()

    Thumbnails = mention_subclass("Thumbnails")

    thumbnails_img = MentionFigures()

    class HasFigures(_Matcher):
        """Match figures whose image file exists and exceeds 50px per side."""

        def _f(self, m):
            file_path = ""
            # If the file exists under several prefixes, the last one wins.
            for prefix in [
                f"{dirname}/data/train/html/",
                f"{dirname}/data/dev/html/",
                f"{dirname}/data/test/html/",
            ]:
                if os.path.exists(prefix + m.figure.url):
                    file_path = prefix + m.figure.url
            if file_path == "":
                return False
            img = Image.open(file_path)
            width, height = img.size
            min_value = min(width, height)
            return min_value > 50

    mention_extractor = MentionExtractor(
        session, [Thumbnails], [thumbnails_img], [HasFigures()], parallelism=parallel
    )

    if first_time:
        mention_extractor.apply(docs)

    logger.info("Total Mentions: {}".format(session.query(Mention).count()))

    ThumbnailLabel = candidate_subclass("ThumbnailLabel", [Thumbnails])

    candidate_extractor = CandidateExtractor(
        session, [ThumbnailLabel], throttlers=[None], parallelism=parallel
    )

    if first_time:
        candidate_extractor.apply(train_docs, split=0)
        candidate_extractor.apply(dev_docs, split=1)
        candidate_extractor.apply(test_docs, split=2)

    train_cands = candidate_extractor.get_candidates(split=0)
    # Sort the dev_cands, which are used for training, for deterministic behavior
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2)

    end = timer()
    logger.warning(f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info("Total train candidate:\t{}".format(len(train_cands[0])))
    logger.info("Total dev candidate:\t{}".format(len(dev_cands[0])))
    logger.info("Total test candidate:\t{}".format(len(test_cands[0])))

    # Each ground-truth line is lower-cased and its whitespace-separated
    # fields are joined with "::" to form one lookup key. The context manager
    # closes the file even if reading fails.
    gt = set()
    with open(f"{dirname}/data/ground_truth.txt", "r") as fin:
        for line in fin:
            gt.add("::".join(line.lower().split()))

    # Labeling
    start = timer()

    def LF_gt_label(c):
        # Key format mirrors the ground-truth entries: "<doc>.pdf::<image>".
        doc_file_id = (
            f"{c[0].context.figure.document.name.lower()}.pdf::"
            f"{os.path.basename(c[0].context.figure.url.lower())}"
        )
        return TRUE if doc_file_id in gt else FALSE

    gt_dev = [LF_gt_label(cand) for cand in dev_cands[0]]
    gt_test = [LF_gt_label(cand) for cand in test_cands[0]]

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    batch_size = 64
    input_size = 224
    K = 2

    emmental.init(log_dir=Meta.log_path, config=emmental_config)

    # Replace the configured task scheduler with a DauphinScheduler instance.
    emmental.Meta.config["learner_config"]["task_scheduler_config"][
        "task_scheduler"
    ] = DauphinScheduler(augment_k=K, enlarge=1)

    # NOTE: both the "train" and the "valid" datasets are built from the dev
    # candidates; only the train variant uses augmentation and prob labels.
    train_dataset = ThumbnailDataset(
        "Thumbnail",
        dev_cands[0],
        gt_dev,
        "train",
        prob_label=True,
        prefix=f"{dirname}/data/dev/html/",
        input_size=input_size,
        transform_cls=Augmentation(2),
        k=K,
    )

    val_dataset = ThumbnailDataset(
        "Thumbnail",
        dev_cands[0],
        gt_dev,
        "valid",
        prob_label=False,
        prefix=f"{dirname}/data/dev/html/",
        input_size=input_size,
        k=1,
    )

    test_dataset = ThumbnailDataset(
        "Thumbnail",
        test_cands[0],
        gt_test,
        "test",
        prob_label=False,
        prefix=f"{dirname}/data/test/html/",
        input_size=input_size,
        k=1,
    )

    dataloaders = []

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=train_dataset,
            split="train",
            shuffle=True,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=val_dataset,
            split="valid",
            shuffle=False,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=test_dataset,
            split="test",
            shuffle=False,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    model = EmmentalModel(name="Thumbnail")
    model.add_task(
        create_task("Thumbnail", n_class=2, model="resnet18", pretrained=True)
    )

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, dataloaders)

    scores = model.score(dataloaders)

    logger.warning("Model Score:")
    logger.warning(f"precision: {scores['Thumbnail/Thumbnail/test/precision']:.3f}")
    logger.warning(f"recall: {scores['Thumbnail/Thumbnail/test/recall']:.3f}")
    logger.warning(f"f1: {scores['Thumbnail/Thumbnail/test/f1']:.3f}")