コード例 #1
0
def test_vision_model():
    """Verify that the vision model runs and yields output of the expected shape."""
    vision_model = load_model("vision", "tests/input/paleo_visual_model.h5")
    # wand indexes pages from 0, so this selects index 0 of the document.
    page_index = 0
    page_image, prediction = predict_heatmap(
        "tests/input/paleo.pdf", page_index, vision_model
    )
    # The image is expected as 448x448 with 3 channels, the prediction as 448x448.
    assert page_image.shape == (448, 448, 3)
    assert prediction.shape == (448, 448)
コード例 #2
0
    def get_tree_structure(self, model_type, model, favor_figures) -> Dict[str, Any]:
        """Find the tables on every page, then build and return the page trees.

        The table-detection strategy is chosen by ``model_type``:

        * ``"vision"``: boxes come from ``predict_heatmap`` + ``get_bboxes``.
        * ``"ml"``: candidates whose ``model.predict`` output exceeds 0.5 are kept.
        * any other value: ``self.get_tables_page_num`` is used.

        Args:
            model_type: Selects the detection strategy listed above.
            model: Passed to ``predict_heatmap`` (vision) or queried through
                ``model.predict`` (ml); not used by the fallback strategy.
            favor_figures: Forwarded unchanged to ``parse_tree_structure``.

        Returns:
            ``self.tree``, after one entry has been assigned for each page
            key in ``self.elems``.
        """
        tables = {}

        if model_type == "vision":
            # Function-level import of the vision helpers (presumably so the
            # vision dependencies are only needed for this mode - confirm).
            from pdftotree.visual.visual_utils import get_bboxes, predict_heatmap

            for page_num in self.elems:
                layout = self.elems[page_num].layout
                page_info = (page_num, int(layout.width), int(layout.height))
                # wand indexes pages from 0, hence the shift by one.
                image, pred = predict_heatmap(self.pdf_file, page_num - 1, model)
                boxes, _ = get_bboxes(image, pred)
                # Each box arrives as (left, top, width, height) and is stored
                # as (top, left, top + height, left + width) after the page info.
                tables[page_num] = [
                    page_info + (top, left, top + height, left + width)
                    for (left, top, width, height) in boxes
                ]

        elif model_type == "ml":
            for page_num in self.elems:
                candidates, features = self.get_candidates_and_features_page_num(
                    page_num
                )
                if len(features) == 0:
                    # Nothing to classify on this page.
                    tables[page_num] = []
                    continue
                scores = model.predict(features)
                tables[page_num] = [
                    candidate
                    for idx, candidate in enumerate(candidates)
                    if scores[idx] > 0.5
                ]

        else:
            # No recognised model type: fall back to the heuristic detector.
            for page_num in self.elems:
                tables[page_num] = self.get_tables_page_num(page_num)

        # Reference-tracking flag: every call receives the value returned by
        # the previous page's call and hands back the updated one.
        references_seen = False
        for page_num in self.elems:
            page_tree, references_seen = parse_tree_structure(
                self.elems[page_num],
                self.font_stats[page_num],
                page_num,
                references_seen,
                tables[page_num],
                favor_figures,
            )
            self.tree[page_num] = page_tree
        return self.tree