Example #1
def load_predictor(det_arch: str, reco_arch: str, device) -> OCRPredictor:
    """
    Args:
        device is tf.device
    """
    with device:
        predictor = ocr_predictor(
            det_arch, reco_arch, pretrained=True, assume_straight_pages=("rotation" not in det_arch)
        )
    return predictor
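A minimal usage sketch for this helper, assuming a TensorFlow backend (the architecture names are common docTR choices, not mandated by the function):

import tensorflow as tf

# Fall back to CPU when no GPU is visible
device_name = "/gpu:0" if tf.config.list_physical_devices("GPU") else "/cpu:0"
predictor = load_predictor("db_resnet50", "crnn_vgg16_bn", tf.device(device_name))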
Example #2
File: analyze.py (Project: mindee/doctr)
def main(args):

    model = ocr_predictor(args.detection, args.recognition, pretrained=True)

    if args.path.lower().endswith(".pdf"):
        doc = DocumentFile.from_pdf(args.path)
    else:
        doc = DocumentFile.from_images(args.path)

    out = model(doc)

    for page, img in zip(out.pages, doc):
        page.show(img, block=not args.noblock, interactive=not args.static)
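The attributes consumed above imply a small CLI; a hedged reconstruction of the argument parsing (argument names are inferred from the code, the defaults are assumptions):

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="docTR end-to-end analysis")
    parser.add_argument("path", type=str, help="Path to the input image or PDF")
    parser.add_argument("--detection", type=str, default="db_resnet50",
                        help="Text detection architecture")
    parser.add_argument("--recognition", type=str, default="crnn_vgg16_bn",
                        help="Text recognition architecture")
    parser.add_argument("--noblock", action="store_true",
                        help="Do not block the matplotlib window")
    parser.add_argument("--static", action="store_true",
                        help="Render a static (non-interactive) visualization")
    return parser.parse_args()

if __name__ == "__main__":
    main(parse_args())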
Example #3
def test_zoo_models(det_arch, reco_arch):
    # Model
    predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True)
    # Output checks
    assert isinstance(predictor, OCRPredictor)

    doc = [np.zeros((512, 512, 3), dtype=np.uint8)]
    out = predictor(doc)
    # Document
    assert isinstance(out, Document)

    # The input doc has 1 page
    assert len(out.pages) == 1
    # Dimension check
    input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8)
    with pytest.raises(ValueError):
        _ = predictor([input_page])
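A test like this is normally driven by pytest parametrization over the model zoo; a sketch of the wiring (the architecture pairs listed here are assumptions, not the project's actual matrix):

import pytest

@pytest.mark.parametrize(
    "det_arch, reco_arch",
    [
        ("db_resnet50", "crnn_vgg16_bn"),
        ("db_mobilenet_v3_large", "sar_resnet31"),
    ],
)
def test_zoo_models(det_arch, reco_arch):
    ...  # body as shown above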
Example #4
File: detect_text.py (Project: mindee/doctr)
def main(args):
    model = ocr_predictor(args.detection, args.recognition, pretrained=True)
    path = Path(args.path)
    if path.is_dir():
        allowed = (".pdf", ".jpeg", ".jpg", ".png", ".tif", ".tiff", ".bmp")
        to_process = [
            f for f in path.iterdir() if str(f).lower().endswith(allowed)
        ]
        for filename in tqdm(to_process):
            # path.iterdir() yields full paths, so build the output name from the basename
            out_path = path.joinpath(f"{filename.name}.{args.format}")
            if out_path.exists():
                continue
            in_path = filename
            out_str = _process_file(model, in_path, args.format)
            with open(out_path, "w") as fh:
                fh.write(out_str)
    else:
        out_str = _process_file(model, path, args.format)
        print(out_str)
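_process_file is referenced but not shown; a plausible sketch built on docTR's documented output API (Document.render() for plain text, Document.export() for a dict), assuming a recent docTR where DocumentFile lives in doctr.io:

import json

from doctr.io import DocumentFile

def _process_file(model, file_path, out_format):
    # Hypothetical helper: load the document, run OCR, serialize the result
    if str(file_path).lower().endswith(".pdf"):
        doc = DocumentFile.from_pdf(file_path)
    else:
        doc = DocumentFile.from_images(file_path)
    out = model(doc)
    if out_format == "json":
        return json.dumps(out.export(), indent=2)
    return out.render()  # plain-text rendering otherwise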
Example #5
    def _get_doctr_docs(self, raw_documents: List[Path]):
        if not hasattr(self, "doctr_model"):
            self.doctr_model = ocr_predictor(det_arch='db_resnet50',
                                             reco_arch='crnn_vgg16_bn',
                                             pretrained=True)
        list_doctr_docs = []
        for doc in raw_documents:
            if not doc.exists():
                print(f"Doc {doc} could not be found.")
                continue
            res_doctr = None
            try:
                if doc.suffix.lower() == ".pdf":
                    doc_doctr = DocumentFile.from_pdf(doc)
                else:
                    doc_doctr = DocumentFile.from_images(doc)
                res_doctr = self.doctr_model(doc_doctr)
            except Exception as e:
                print(f"Could not analyze document {doc}. Error: {e}")
            if res_doctr:
                list_doctr_docs.append(res_doctr)

        return list_doctr_docs
Example #6
def test_zoo_models(det_arch, reco_arch):
    # Model
    predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True)
    _test_predictor(predictor)

    # passing model instance directly
    det_model = detection.__dict__[det_arch](pretrained=True)
    reco_model = recognition.__dict__[reco_arch](pretrained=True)
    predictor = models.ocr_predictor(det_model, reco_model)
    _test_predictor(predictor)

    # passing recognition model as detection model
    with pytest.raises(ValueError):
        models.ocr_predictor(det_arch=reco_model, pretrained=True)

    # passing detection model as recognition model
    with pytest.raises(ValueError):
        models.ocr_predictor(reco_arch=det_model, pretrained=True)
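_test_predictor is not shown here; judging from Example #3, it plausibly bundles the shared output checks, along the lines of:

import numpy as np

def _test_predictor(predictor):
    # Sketch of the shared checks, mirroring Example #3 (not the actual helper)
    assert isinstance(predictor, OCRPredictor)
    doc = [np.zeros((512, 512, 3), dtype=np.uint8)]
    out = predictor(doc)
    assert isinstance(out, Document)
    assert len(out.pages) == 1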
Example #7
File: vision.py (Project: mindee/doctr)
# Copyright (C) 2021-2022, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
if any(gpu_devices):
    tf.config.experimental.set_memory_growth(gpu_devices[0], True)

# docTR is imported only after the GPU setup so that memory growth is
# configured before any TensorFlow memory gets allocated
from doctr.models import ocr_predictor

predictor = ocr_predictor(pretrained=True)
det_predictor = predictor.det_predictor
reco_predictor = predictor.reco_predictor
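The two sub-predictors can also be driven independently; a minimal sketch with placeholder inputs (the shapes are illustrative, and the exact output structure varies across docTR versions):

import numpy as np

page = np.zeros((1024, 1024, 3), dtype=np.uint8)    # placeholder page
loc_out = det_predictor([page])                     # text localization only
word_crop = np.zeros((32, 128, 3), dtype=np.uint8)  # placeholder word crop
reco_out = reco_predictor([word_crop])              # [(value, confidence), ...]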
Example #8
from doctr.models import ocr_predictor

if __name__ == "__main__":
    doctr_model = ocr_predictor(det_arch='db_resnet50',
                                reco_arch='crnn_vgg16_bn',
                                pretrained=True)
Example #9
def main(args):

    if not args.rotation:
        args.eval_straight = True

    predictor = ocr_predictor(args.detection,
                              args.recognition,
                              pretrained=True,
                              reco_bs=args.batch_size,
                              assume_straight_pages=not args.rotation)

    if args.img_folder and args.label_file:
        testset = datasets.OCRDataset(
            img_folder=args.img_folder,
            label_file=args.label_file,
        )
        sets = [testset]
    else:
        train_set = datasets.__dict__[args.dataset](
            train=True, download=True, use_polygons=not args.eval_straight)
        val_set = datasets.__dict__[args.dataset](
            train=False, download=True, use_polygons=not args.eval_straight)
        sets = [train_set, val_set]

    reco_metric = TextMatch()
    if args.mask_shape:
        det_metric = LocalizationConfusion(iou_thresh=args.iou,
                                           use_polygons=not args.eval_straight,
                                           mask_shape=(args.mask_shape,
                                                       args.mask_shape))
        e2e_metric = OCRMetric(iou_thresh=args.iou,
                               use_polygons=not args.eval_straight,
                               mask_shape=(args.mask_shape, args.mask_shape))
    else:
        det_metric = LocalizationConfusion(iou_thresh=args.iou,
                                           use_polygons=not args.eval_straight)
        e2e_metric = OCRMetric(iou_thresh=args.iou,
                               use_polygons=not args.eval_straight)

    sample_idx = 0
    extraction_fn = extract_crops if args.eval_straight else extract_rcrops

    for dataset in sets:
        for page, target in tqdm(dataset):
            # GT
            gt_boxes = target['boxes']
            gt_labels = target['labels']

            if args.img_folder and args.label_file:
                x, y, w, h = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2], gt_boxes[:, 3]
                xmin, ymin = np.clip(x - w / 2, 0, 1), np.clip(y - h / 2, 0, 1)
                xmax, ymax = np.clip(x + w / 2, 0, 1), np.clip(y + h / 2, 0, 1)
                gt_boxes = np.stack([xmin, ymin, xmax, ymax], axis=-1)

            # Forward
            if is_tf_available():
                out = predictor(page[None, ...])
                crops = extraction_fn(page, gt_boxes)
                reco_out = predictor.reco_predictor(crops)
            else:
                with torch.no_grad():
                    out = predictor(page[None, ...])
                    # We directly crop on PyTorch tensors, which are in channels_first
                    crops = extraction_fn(page, gt_boxes, channels_last=False)
                    reco_out = predictor.reco_predictor(crops)

            if len(reco_out):
                reco_words, _ = zip(*reco_out)
            else:
                reco_words = []

            # Unpack preds
            pred_boxes = []
            pred_labels = []
            # Rename to avoid shadowing the input image `page` from the outer loop
            for out_page in out.pages:
                height, width = out_page.dimensions
                for block in out_page.blocks:
                    for line in block.lines:
                        for word in line.words:
                            if not args.rotation:
                                (a, b), (c, d) = word.geometry
                            else:
                                (x1, y1), (x2, y2), (x3, y3), (x4, y4) = word.geometry
                            if gt_boxes.dtype == int:
                                if not args.rotation:
                                    pred_boxes.append([
                                        int(a * width), int(b * height),
                                        int(c * width), int(d * height),
                                    ])
                                else:
                                    if args.eval_straight:
                                        pred_boxes.append([
                                            int(width * min(x1, x2, x3, x4)),
                                            int(height * min(y1, y2, y3, y4)),
                                            int(width * max(x1, x2, x3, x4)),
                                            int(height * max(y1, y2, y3, y4)),
                                        ])
                                    else:
                                        pred_boxes.append([
                                            [int(x1 * width), int(y1 * height)],
                                            [int(x2 * width), int(y2 * height)],
                                            [int(x3 * width), int(y3 * height)],
                                            [int(x4 * width), int(y4 * height)],
                                        ])
                            else:
                                if not args.rotation:
                                    pred_boxes.append([a, b, c, d])
                                else:
                                    if args.eval_straight:
                                        pred_boxes.append([
                                            min(x1, x2, x3, x4),
                                            min(y1, y2, y3, y4),
                                            max(x1, x2, x3, x4),
                                            max(y1, y2, y3, y4),
                                        ])
                                    else:
                                        pred_boxes.append([[x1, y1], [x2, y2],
                                                           [x3, y3], [x4, y4]])
                            pred_labels.append(word.value)

            # Update the metric
            det_metric.update(gt_boxes, np.asarray(pred_boxes))
            reco_metric.update(gt_labels, reco_words)
            e2e_metric.update(gt_boxes, np.asarray(pred_boxes), gt_labels,
                              pred_labels)

            # Loop break
            sample_idx += 1
            if isinstance(args.samples, int) and args.samples == sample_idx:
                break
        if isinstance(args.samples, int) and args.samples == sample_idx:
            break

    # Unpack aggregated metrics
    print(f"Model Evaluation (model= {args.detection} + {args.recognition}, "
          f"dataset={'OCRDataset' if args.img_folder else args.dataset})")
    recall, precision, mean_iou = det_metric.summary()
    print(
        f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}"
    )
    acc = reco_metric.summary()
    print(
        f"Text Recognition - Accuracy: {_pct(acc['raw'])} (unicase: {_pct(acc['unicase'])})"
    )
    recall, precision, mean_iou = e2e_metric.summary()
    print(
        f"OCR - Recall: {_pct(recall['raw'])} (unicase: {_pct(recall['unicase'])}), "
        f"Precision: {_pct(precision['raw'])} (unicase: {_pct(precision['unicase'])}), Mean IoU: {_pct(mean_iou)}"
    )
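The _pct helper used in these report lines is not shown above; a formatter consistent with the printouts would be something like (a sketch, tolerating None for metrics that were never updated):

def _pct(val):
    # Format a ratio as a percentage, or N/A when the metric is undefined
    return "N/A" if val is None else f"{val:.2%}"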
Example #10
def main():

    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("docTR: Document Text Recognition")
    # For newline
    st.write('\n')
    # Instructions
    st.markdown(
        "*Hint: click on the top-right corner of an image to enlarge it!*")
    # Set the columns
    cols = st.columns((1, 1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")
    cols[2].subheader("OCR output")
    cols[3].subheader("Page reconstitution")

    # Sidebar
    # File selection
    st.sidebar.title("Document selection")
    # Disabling warning
    st.set_option('deprecation.showfileUploaderEncoding', False)
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader(
        "Upload files", type=['pdf', 'png', 'jpeg', 'jpg'])
    if uploaded_file is not None:
        if uploaded_file.name.endswith('.pdf'):
            doc = DocumentFile.from_pdf(uploaded_file.read()).as_images()
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
        page_idx = st.sidebar.selectbox(
            "Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        cols[0].image(doc[page_idx])

    # Model selection
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", DET_ARCHS)
    reco_arch = st.sidebar.selectbox("Text recognition model", RECO_ARCHS)

    # For newline
    st.sidebar.write('\n')

    if st.sidebar.button("Analyze page"):

        if uploaded_file is None:
            st.sidebar.write("Please upload a document")

        else:
            with st.spinner('Loading model...'):
                predictor = ocr_predictor(det_arch, reco_arch, pretrained=True)

            with st.spinner('Analyzing...'):

                # Forward the image to the model
                processed_batches = predictor.det_predictor.pre_processor(
                    [doc[page_idx]])
                out = predictor.det_predictor.model(processed_batches[0],
                                                    return_model_output=True)
                seg_map = out["out_map"]
                seg_map = tf.squeeze(seg_map[0, ...], axis=[2])
                seg_map = cv2.resize(
                    seg_map.numpy(),
                    (doc[page_idx].shape[1], doc[page_idx].shape[0]),
                    interpolation=cv2.INTER_LINEAR)
                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(seg_map)
                ax.axis('off')
                cols[1].pyplot(fig)

                # Plot OCR output
                out = predictor([doc[page_idx]])
                fig = visualize_page(out.pages[0].export(),
                                     doc[page_idx],
                                     interactive=False)
                cols[2].pyplot(fig)

                # Page reconstitution under the input page
                page_export = out.pages[0].export()
                img = out.pages[0].synthesize()
                cols[3].image(img, clamp=True)

                # Display JSON
                st.markdown("\nHere are your analysis results in JSON format:")
                st.json(page_export)
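One easy refinement for this demo is to cache the predictor so it is not rebuilt on every Streamlit rerun; a sketch using the resource-caching decorator (st.cache_resource requires Streamlit >= 1.18; older releases used st.cache instead):

import streamlit as st
from doctr.models import ocr_predictor

@st.cache_resource  # assumption: Streamlit >= 1.18
def load_model(det_arch: str, reco_arch: str):
    return ocr_predictor(det_arch, reco_arch, pretrained=True)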