Example #1
File: zoo.py Project: mindee/doctr
def _predictor(arch: Any,
               pretrained: bool,
               assume_straight_pages: bool = True,
               **kwargs: Any) -> DetectionPredictor:

    if isinstance(arch, str):
        if arch not in ARCHS + ROT_ARCHS:
            raise ValueError(f"unknown architecture '{arch}'")

        if arch not in ROT_ARCHS and not assume_straight_pages:
            raise AssertionError(
                "You are trying to use a model trained on straight pages while not assuming"
                " your pages are straight. If you have only straight documents, don't pass"
                " assume_straight_pages=False, otherwise you should use one of these archs:"
                f"{ROT_ARCHS}")

        _model = detection.__dict__[arch](
            pretrained=pretrained, assume_straight_pages=assume_straight_pages)
    else:
        if not isinstance(arch, (detection.DBNet, detection.LinkNet)):
            raise ValueError(f"unknown architecture: {type(arch)}")

        _model = arch
        _model.assume_straight_pages = assume_straight_pages

    kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
    kwargs["std"] = kwargs.get("std", _model.cfg["std"])
    kwargs["batch_size"] = kwargs.get("batch_size", 1)
    predictor = DetectionPredictor(
        PreProcessor(
            _model.cfg["input_shape"][:-1]
            if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs),
        _model,
    )
    return predictor
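
A minimal usage sketch (hypothetical calls; `_predictor` is a private helper normally reached through doctr's public `detection_predictor` function, and "db_resnet50" is assumed to be listed in ARCHS):

# Build from an architecture name (assumed to be in ARCHS)
predictor = _predictor("db_resnet50", pretrained=True)

# ...or from an already-instantiated DBNet/LinkNet model; `pretrained` is then unused
model = detection.db_resnet50(pretrained=True)
predictor = _predictor(model, pretrained=False, assume_straight_pages=True)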
Example #2
def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str,
                                      task: str) -> None:
    """Save model and config to disk for pushing to huggingface hub

    Args:
        model: TF or PyTorch model to be saved
        save_dir: directory to save model and config
        arch: architecture name
        task: task name
    """
    save_directory = Path(save_dir)

    if is_torch_available():
        weights_path = save_directory / "pytorch_model.bin"
        torch.save(model.state_dict(), weights_path)
    elif is_tf_available():
        weights_path = save_directory / "tf_model" / "weights"
        model.save_weights(str(weights_path))

    config_path = save_directory / "config.json"

    # add model configuration
    model_config = model.cfg
    model_config["arch"] = arch
    model_config["task"] = task

    with config_path.open("w") as f:
        json.dump(model_config, f, indent=2, ensure_ascii=False)
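
A usage sketch, assuming a standard doctr model exposing a plain-dict `.cfg` (the export path and names are illustrative); note that the helper writes into `save_dir` without creating it first:

from pathlib import Path
from doctr.models import recognition

save_dir = Path("./hub_export")  # hypothetical export location
save_dir.mkdir(parents=True, exist_ok=True)
model = recognition.crnn_vgg16_bn(pretrained=True)
_save_model_and_config_for_hf_hub(model, str(save_dir), arch="crnn_vgg16_bn", task="recognition")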
Example #3
def _predictor(arch: str,
               pretrained: bool,
               assume_straight_pages: bool = True,
               **kwargs: Any) -> DetectionPredictor:

    if arch not in ARCHS:
        raise ValueError(f"unknown architecture '{arch}'")

    if arch not in ROT_ARCHS and not assume_straight_pages:
        raise AssertionError(
            "You are trying to use a model trained on straight pages while not assuming"
            " your pages are straight. If you have only straight documents, don't pass"
            f" assume_straight_pages=False, otherwise you should use one of these archs: {ROT_ARCHS}"
        )

    # Detection
    _model = detection.__dict__[arch](
        pretrained=pretrained, assume_straight_pages=assume_straight_pages)
    kwargs['mean'] = kwargs.get('mean', _model.cfg['mean'])
    kwargs['std'] = kwargs.get('std', _model.cfg['std'])
    kwargs['batch_size'] = kwargs.get('batch_size', 1)
    predictor = DetectionPredictor(
        PreProcessor(
            _model.cfg['input_shape'][:-1] if is_tf_available() else _model.cfg['input_shape'][1:],
            **kwargs),
        _model)
    return predictor
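
The kwargs act as overridable preprocessing defaults: `mean`, `std`, and `batch_size` fall back to the model config when not supplied. A sketch (override values are illustrative):

# Defaults come from _model.cfg unless explicitly overridden
predictor = _predictor("db_resnet50", pretrained=True, batch_size=4)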
Example #4
def _predictor(arch: str, pretrained: bool,
               **kwargs: Any) -> RecognitionPredictor:

    if arch not in ARCHS:
        raise ValueError(f"unknown architecture '{arch}'")

    _model = recognition.__dict__[arch](pretrained=pretrained)
    kwargs['mean'] = kwargs.get('mean', _model.cfg['mean'])
    kwargs['std'] = kwargs.get('std', _model.cfg['std'])
    kwargs['batch_size'] = kwargs.get('batch_size', 32)
    input_shape = _model.cfg['input_shape'][:2] if is_tf_available() else _model.cfg['input_shape'][-2:]
    predictor = RecognitionPredictor(
        PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs),
        _model)

    return predictor
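
A usage sketch, assuming "crnn_vgg16_bn" is among the recognition ARCHS; `preserve_aspect_ratio=True` matters here because word crops vary widely in width:

predictor = _predictor("crnn_vgg16_bn", pretrained=True, batch_size=64)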
Example #5
File: zoo.py Project: mindee/doctr
def _crop_orientation_predictor(arch: str, pretrained: bool,
                                **kwargs: Any) -> CropOrientationPredictor:

    if arch not in ORIENTATION_ARCHS:
        raise ValueError(f"unknown architecture '{arch}'")

    # Load directly classifier from backbone
    _model = classification.__dict__[arch](pretrained=pretrained)
    kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
    kwargs["std"] = kwargs.get("std", _model.cfg["std"])
    kwargs["batch_size"] = kwargs.get("batch_size", 64)
    input_shape = _model.cfg["input_shape"][:-1] if is_tf_available(
    ) else _model.cfg["input_shape"][1:]
    predictor = CropOrientationPredictor(
        PreProcessor(input_shape,
                     preserve_aspect_ratio=True,
                     symmetric_pad=True,
                     **kwargs), _model)
    return predictor
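
A usage sketch, assuming ORIENTATION_ARCHS includes doctr's MobileNet-based orientation classifier (the arch name may differ across versions):

predictor = _crop_orientation_predictor("mobilenet_v3_small_orientation", pretrained=True)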
Example #6
def _predictor(arch: Any, pretrained: bool,
               **kwargs: Any) -> RecognitionPredictor:

    if isinstance(arch, str):
        if arch not in ARCHS:
            raise ValueError(f"unknown architecture '{arch}'")

        _model = recognition.__dict__[arch](pretrained=pretrained)
    else:
        if not isinstance(
                arch, (recognition.CRNN, recognition.SAR, recognition.MASTER)):
            raise ValueError(f"unknown architecture: {type(arch)}")
        _model = arch

    kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
    kwargs["std"] = kwargs.get("std", _model.cfg["std"])
    kwargs["batch_size"] = kwargs.get("batch_size", 32)
    input_shape = _model.cfg["input_shape"][:2] if is_tf_available(
    ) else _model.cfg["input_shape"][-2:]
    predictor = RecognitionPredictor(
        PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs),
        _model)

    return predictor
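
This variant also accepts a pre-built model instance, in which case `pretrained` is ignored; a sketch:

model = recognition.sar_resnet31(pretrained=True)
predictor = _predictor(model, pretrained=False)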
Example #7
from doctr.file_utils import is_tf_available

from .generator import *
from .cord import *
from .detection import *
from .doc_artefacts import *
from .funsd import *
from .ic03 import *
from .ic13 import *
from .iiit5k import *
from .imgur5k import *
from .ocr import *
from .recognition import *
from .sroie import *
from .svhn import *
from .svt import *
from .synthtext import *
from .utils import *
from .vocabs import *

if is_tf_available():
    from .loader import *
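
After these re-exports, every dataset class is reachable from the package root; a usage sketch (download location and index are illustrative):

from doctr import datasets

train_set = datasets.FUNSD(train=True, download=True)
img, target = train_set[0]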
Example #8
def main(args):

    if not args.rotation:
        args.eval_straight = True

    predictor = ocr_predictor(args.detection,
                              args.recognition,
                              pretrained=True,
                              reco_bs=args.batch_size,
                              assume_straight_pages=not args.rotation)

    if args.img_folder and args.label_file:
        testset = datasets.OCRDataset(
            img_folder=args.img_folder,
            label_file=args.label_file,
        )
        sets = [testset]
    else:
        train_set = datasets.__dict__[args.dataset](
            train=True, download=True, use_polygons=not args.eval_straight)
        val_set = datasets.__dict__[args.dataset](
            train=False, download=True, use_polygons=not args.eval_straight)
        sets = [train_set, val_set]

    reco_metric = TextMatch()
    if args.mask_shape:
        det_metric = LocalizationConfusion(iou_thresh=args.iou,
                                           use_polygons=not args.eval_straight,
                                           mask_shape=(args.mask_shape, args.mask_shape))
        e2e_metric = OCRMetric(iou_thresh=args.iou,
                               use_polygons=not args.eval_straight,
                               mask_shape=(args.mask_shape, args.mask_shape))
    else:
        det_metric = LocalizationConfusion(iou_thresh=args.iou,
                                           use_polygons=not args.eval_straight)
        e2e_metric = OCRMetric(iou_thresh=args.iou,
                               use_polygons=not args.eval_straight)

    sample_idx = 0
    extraction_fn = extract_crops if args.eval_straight else extract_rcrops

    for dataset in sets:
        for page, target in tqdm(dataset):
            # GT
            gt_boxes = target['boxes']
            gt_labels = target['labels']

            if args.img_folder and args.label_file:
                x, y, w, h = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2], gt_boxes[:, 3]
                xmin, ymin = np.clip(x - w / 2, 0, 1), np.clip(y - h / 2, 0, 1)
                xmax, ymax = np.clip(x + w / 2, 0, 1), np.clip(y + h / 2, 0, 1)
                gt_boxes = np.stack([xmin, ymin, xmax, ymax], axis=-1)

            # Forward
            if is_tf_available():
                out = predictor(page[None, ...])
                crops = extraction_fn(page, gt_boxes)
                reco_out = predictor.reco_predictor(crops)
            else:
                with torch.no_grad():
                    out = predictor(page[None, ...])
                    # We directly crop on PyTorch tensors, which are in channels_first
                    crops = extraction_fn(page, gt_boxes, channels_last=False)
                    reco_out = predictor.reco_predictor(crops)

            if len(reco_out):
                reco_words, _ = zip(*reco_out)
            else:
                reco_words = []

            # Unpack preds
            pred_boxes = []
            pred_labels = []
            for page in out.pages:
                height, width = page.dimensions
                for block in page.blocks:
                    for line in block.lines:
                        for word in line.words:
                            if not args.rotation:
                                (a, b), (c, d) = word.geometry
                            else:
                                [x1, y1], [x2, y2], [x3, y3], [x4, y4] = word.geometry
                            if gt_boxes.dtype == int:
                                if not args.rotation:
                                    pred_boxes.append([
                                        int(a * width),
                                        int(b * height),
                                        int(c * width),
                                        int(d * height)
                                    ])
                                else:
                                    if args.eval_straight:
                                        pred_boxes.append([
                                            int(width * min(x1, x2, x3, x4)),
                                            int(height * min(y1, y2, y3, y4)),
                                            int(width * max(x1, x2, x3, x4)),
                                            int(height * max(y1, y2, y3, y4)),
                                        ])
                                    else:
                                        pred_boxes.append([
                                            [int(x1 * width), int(y1 * height)],
                                            [int(x2 * width), int(y2 * height)],
                                            [int(x3 * width), int(y3 * height)],
                                            [int(x4 * width), int(y4 * height)],
                                        ])
                            else:
                                if not args.rotation:
                                    pred_boxes.append([a, b, c, d])
                                else:
                                    if args.eval_straight:
                                        pred_boxes.append([
                                            min(x1, x2, x3, x4),
                                            min(y1, y2, y3, y4),
                                            max(x1, x2, x3, x4),
                                            max(y1, y2, y3, y4),
                                        ])
                                    else:
                                        pred_boxes.append([[x1, y1], [x2, y2],
                                                           [x3, y3], [x4, y4]])
                            pred_labels.append(word.value)

            # Update the metric
            det_metric.update(gt_boxes, np.asarray(pred_boxes))
            reco_metric.update(gt_labels, reco_words)
            e2e_metric.update(gt_boxes, np.asarray(pred_boxes), gt_labels,
                              pred_labels)

            # Loop break
            sample_idx += 1
            if isinstance(args.samples, int) and args.samples == sample_idx:
                break
        if isinstance(args.samples, int) and args.samples == sample_idx:
            break

    # Unpack aggregated metrics
    print(f"Model Evaluation (model= {args.detection} + {args.recognition}, "
          f"dataset={'OCRDataset' if args.img_folder else args.dataset})")
    recall, precision, mean_iou = det_metric.summary()
    print(
        f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}"
    )
    acc = reco_metric.summary()
    print(
        f"Text Recognition - Accuracy: {_pct(acc['raw'])} (unicase: {_pct(acc['unicase'])})"
    )
    recall, precision, mean_iou = e2e_metric.summary()
    print(
        f"OCR - Recall: {_pct(recall['raw'])} (unicase: {_pct(recall['unicase'])}), "
        f"Precision: {_pct(precision['raw'])} (unicase: {_pct(precision['unicase'])}), Mean IoU: {_pct(mean_iou)}"
    )
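
`_pct` is not defined in this excerpt; a minimal sketch consistent with how it is called on optional ratio values:

def _pct(val):
    # Assumed helper: format an optional ratio (e.g. 0.874) as a percentage string
    return "N/A" if val is None else f"{val:.2%}"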
Example #9
File: base.py Project: mindee/doctr
    def build_target(
        self,
        target: List[np.ndarray],
        output_shape: Tuple[int, int],
    ) -> Tuple[np.ndarray, np.ndarray]:

        if any(t.dtype != np.float32 for t in target):
            raise AssertionError(
                "the expected dtype of target 'boxes' entry is 'np.float32'.")
        if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for t in target):
            raise ValueError(
                "the 'boxes' entry of the target is expected to take values between 0 & 1."
            )

        h, w = output_shape
        target_shape = (len(target), h, w, 1)

        seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
        seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)

        for idx, _target in enumerate(target):
            # Draw each polygon on gt
            if _target.shape[0] == 0:
                # Empty image, full masked
                seg_mask[idx] = False

            # Absolute bounding boxes
            abs_boxes = _target.copy()

            if abs_boxes.ndim == 3:
                abs_boxes[:, :, 0] *= w
                abs_boxes[:, :, 1] *= h
                polys = abs_boxes
                boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1)
                abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32)
            else:
                abs_boxes[:, [0, 2]] *= w
                abs_boxes[:, [1, 3]] *= h
                abs_boxes = abs_boxes.round().astype(np.int32)
                polys = np.stack(
                    [
                        abs_boxes[:, [0, 1]],
                        abs_boxes[:, [0, 3]],
                        abs_boxes[:, [2, 3]],
                        abs_boxes[:, [2, 1]],
                    ],
                    axis=1,
                )
                boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0],
                                        abs_boxes[:, 3] - abs_boxes[:, 1])

            for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
                # Mask boxes that are too small
                if box_size < self.min_size_box:
                    seg_mask[idx, box[1]:box[3] + 1, box[0]:box[2] + 1] = False
                    continue

                # Negative shrink for gt, as described in paper
                polygon = Polygon(poly)
                distance = polygon.area * (
                    1 - np.power(self.shrink_ratio, 2)) / polygon.length
                subject = [tuple(coor) for coor in poly]
                padding = pyclipper.PyclipperOffset()
                padding.AddPath(subject, pyclipper.JT_ROUND,
                                pyclipper.ET_CLOSEDPOLYGON)
                shrunken = padding.Execute(-distance)

                # Draw polygon on gt if it is valid
                if len(shrunken) == 0:
                    seg_mask[idx, box[1]:box[3] + 1, box[0]:box[2] + 1] = False
                    continue
                shrunken = np.array(shrunken[0]).reshape(-1, 2)
                if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                    seg_mask[idx, box[1]:box[3] + 1, box[0]:box[2] + 1] = False
                    continue
                cv2.fillPoly(seg_target[idx], [shrunken.astype(np.int32)], 1)

        # Don't forget to switch back to channel first if PyTorch is used
        if not is_tf_available():
            seg_target = seg_target.transpose(0, 3, 1, 2)
            seg_mask = seg_mask.transpose(0, 3, 1, 2)

        return seg_target, seg_mask
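
A usage sketch on a detection model exposing this method (the `model` name and box values are illustrative): relative xmin/ymin/xmax/ymax boxes for the first page, an empty array for the second:

targets = [
    np.array([[0.1, 0.2, 0.4, 0.35], [0.5, 0.5, 0.9, 0.6]], dtype=np.float32),
    np.zeros((0, 4), dtype=np.float32),  # empty page -> fully masked
]
seg_target, seg_mask = model.build_target(targets, output_shape=(1024, 1024))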
Example #10
def main(det_archs, reco_archs):
    """Build a streamlit layout"""

    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("docTR: Document Text Recognition")
    # For newline
    st.write("\n")
    # Instructions
    st.markdown(
        "*Hint: click on the top-right corner of an image to enlarge it!*")
    # Set the columns
    cols = st.columns((1, 1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")
    cols[2].subheader("OCR output")
    cols[3].subheader("Page reconstitution")

    # Sidebar
    # File selection
    st.sidebar.title("Document selection")
    # Disabling warning
    st.set_option("deprecation.showfileUploaderEncoding", False)
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader(
        "Upload files", type=["pdf", "png", "jpeg", "jpg"])
    if uploaded_file is not None:
        if uploaded_file.name.endswith(".pdf"):
            doc = DocumentFile.from_pdf(uploaded_file.read())
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
        page_idx = st.sidebar.selectbox(
            "Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        page = doc[page_idx]
        cols[0].image(page)

    # Model selection
    st.sidebar.title("Model selection")
    st.sidebar.markdown("**Backend**: " +
                        ("TensorFlow" if is_tf_available() else "PyTorch"))
    det_arch = st.sidebar.selectbox("Text detection model", det_archs)
    reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)

    # For newline
    st.sidebar.write("\n")

    if st.sidebar.button("Analyze page"):

        if uploaded_file is None:
            st.sidebar.write("Please upload a document")

        else:
            with st.spinner("Loading model..."):
                predictor = load_predictor(det_arch, reco_arch, forward_device)

            with st.spinner("Analyzing..."):

                # Forward the image to the model
                seg_map = forward_image(predictor, page, forward_device)
                seg_map = np.squeeze(seg_map)
                seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]),
                                     interpolation=cv2.INTER_LINEAR)

                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(seg_map)
                ax.axis("off")
                cols[1].pyplot(fig)

                # Plot OCR output
                out = predictor([page])
                fig = visualize_page(out.pages[0].export(),
                                     page,
                                     interactive=False)
                cols[2].pyplot(fig)

                # Page reconstitution under the input page
                page_export = out.pages[0].export()
                if "rotation" not in det_arch:
                    img = out.pages[0].synthesize()
                    cols[3].image(img, clamp=True)

                # Display JSON
                st.markdown("\nHere are your analysis results in JSON format:")
                st.json(page_export)
Example #11
from doctr.file_utils import is_tf_available, is_torch_available

if not is_tf_available() and is_torch_available():
    from .pytorch import *
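
Sibling modules in doctr typically pair this with a TensorFlow branch, giving TensorFlow precedence when both backends are installed; a sketch of that fuller convention (file names assumed):

from doctr.file_utils import is_tf_available, is_torch_available

if is_tf_available():
    from .tensorflow import *
elif is_torch_available():
    from .pytorch import *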