def load_predictor(det_arch: str, reco_arch: str, device) -> OCRPredictor:
    """Instantiate a pretrained OCR predictor on the given device.

    Args:
        det_arch: name of the detection architecture
        reco_arch: name of the recognition architecture
        device: a tf.device context manager to load the model on
    """
    with device:
        predictor = ocr_predictor(
            det_arch,
            reco_arch,
            pretrained=True,
            assume_straight_pages=("rotation" not in det_arch),
        )
    return predictor

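# Hedged usage sketch for load_predictor (not part of the original snippet):
# it assumes the TensorFlow backend, where tf.device(...) returns the context
# manager the function expects; the architecture names are just examples.
import tensorflow as tf

device = tf.device("/gpu:0" if tf.config.list_physical_devices("GPU") else "/cpu:0")
predictor = load_predictor("db_resnet50", "crnn_vgg16_bn", device)
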
def main(args):
    model = ocr_predictor(args.detection, args.recognition, pretrained=True)

    if args.path.lower().endswith(".pdf"):
        doc = DocumentFile.from_pdf(args.path)
    else:
        doc = DocumentFile.from_images(args.path)

    out = model(doc)

    for page, img in zip(out.pages, doc):
        page.show(img, block=not args.noblock, interactive=not args.static)

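# Hedged sketch of the CLI wiring this main() implies. The flag names mirror
# the attributes accessed above (detection, recognition, path, noblock,
# static); their defaults and help strings are assumptions.
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="End-to-end OCR analysis demo")
    parser.add_argument("path", type=str, help="Path to the input document (image or PDF)")
    parser.add_argument("--detection", type=str, default="db_resnet50", help="Text detection architecture")
    parser.add_argument("--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition architecture")
    parser.add_argument("--noblock", action="store_true", help="Do not block on matplotlib windows")
    parser.add_argument("--static", action="store_true", help="Disable interactive plots")
    return parser.parse_args()

if __name__ == "__main__":
    main(parse_args())
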
def test_zoo_models(det_arch, reco_arch):
    # Model
    predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True)
    # Output checks
    assert isinstance(predictor, OCRPredictor)

    doc = [np.zeros((512, 512, 3), dtype=np.uint8)]
    out = predictor(doc)
    # Document
    assert isinstance(out, Document)

    # The input doc has 1 page
    assert len(out.pages) == 1

    # Dimension check
    with pytest.raises(ValueError):
        input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8)
        _ = predictor([input_page])

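# Hedged sketch of how a test like this is typically parametrized; the
# architecture pairs below are illustrative assumptions, not the project's
# actual test matrix.
import pytest

@pytest.mark.parametrize(
    "det_arch, reco_arch",
    [
        ("db_resnet50", "crnn_vgg16_bn"),
        ("linknet_resnet18", "sar_resnet31"),
    ],
)
def test_zoo_models_parametrized(det_arch, reco_arch):
    test_zoo_models(det_arch, reco_arch)
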
def main(args):
    model = ocr_predictor(args.detection, args.recognition, pretrained=True)

    path = Path(args.path)
    if path.is_dir():
        allowed = (".pdf", ".jpeg", ".jpg", ".png", ".tif", ".tiff", ".bmp")
        to_process = [f for f in path.iterdir() if str(f).lower().endswith(allowed)]
        for in_path in tqdm(to_process):
            # iterdir() already yields full paths, so only the file name is
            # reused when building the output path next to the input file
            out_path = path.joinpath(f"{in_path.name}.{args.format}")
            if out_path.exists():
                continue
            out_str = _process_file(model, in_path, args.format)
            with open(out_path, "w") as fh:
                fh.write(out_str)
    else:
        out_str = _process_file(model, path, args.format)
        print(out_str)

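# _process_file is not shown in this excerpt; below is a hedged sketch of
# what such a helper could look like, built only on docTR's documented
# DocumentFile loaders and Document.export()/render() methods. The supported
# format names ("json", "txt") are assumptions.
import json

from doctr.io import DocumentFile

def _process_file_sketch(model, in_path, out_format):
    if str(in_path).lower().endswith(".pdf"):
        doc = DocumentFile.from_pdf(in_path)
    else:
        doc = DocumentFile.from_images(in_path)
    out = model(doc)
    if out_format == "json":
        # default=str guards against scalar types json cannot serialize natively
        return json.dumps(out.export(), indent=2, default=str)
    if out_format == "txt":
        return out.render()
    raise ValueError(f"Unsupported format: {out_format}")
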
def _get_doctr_docs(self, raw_documents: List[Path]):
    if not hasattr(self, "doctr_model"):
        self.doctr_model = ocr_predictor(det_arch="db_resnet50", reco_arch="crnn_vgg16_bn", pretrained=True)
    list_doctr_docs = []
    for doc in raw_documents:
        if not doc.exists():
            print(f"Doc {doc} could not be found.")
            continue
        res_doctr = None
        try:
            # Path.suffix includes the leading dot
            if doc.suffix == ".pdf":
                doc_doctr = DocumentFile.from_pdf(doc)
            else:
                doc_doctr = DocumentFile.from_images(doc)
            res_doctr = self.doctr_model(doc_doctr)
        except Exception as e:
            print(f"Could not analyze document {doc}. Error: {e}")
        if res_doctr:
            list_doctr_docs.append(res_doctr)
    return list_doctr_docs

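# Hedged usage sketch: _get_doctr_docs is an instance method, so it needs a
# host class. DocIndexer below is an illustrative stand-in for whatever class
# it actually lives on, and the file paths are placeholders.
from pathlib import Path

class DocIndexer:
    _get_doctr_docs = _get_doctr_docs  # reuse the function defined above

results = DocIndexer()._get_doctr_docs([Path("invoice.pdf"), Path("receipt.png")])
for res in results:
    print(res.render())  # flatten each docTR Document to plain text
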
def test_zoo_models(det_arch, reco_arch):
    # Model
    predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True)
    _test_predictor(predictor)

    # passing model instance directly
    det_model = detection.__dict__[det_arch](pretrained=True)
    reco_model = recognition.__dict__[reco_arch](pretrained=True)
    predictor = models.ocr_predictor(det_model, reco_model)
    _test_predictor(predictor)

    # passing recognition model as detection model
    with pytest.raises(ValueError):
        models.ocr_predictor(det_arch=reco_model, pretrained=True)

    # passing detection model as recognition model
    with pytest.raises(ValueError):
        models.ocr_predictor(reco_arch=det_model, pretrained=True)

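# _test_predictor is not shown in this excerpt; a plausible hedged sketch is
# a helper that runs the predictor on a dummy page and checks the output
# type, along the lines of the assertions in the previous test.
import numpy as np

from doctr.io import Document

def _test_predictor_sketch(predictor):
    doc = [np.zeros((512, 512, 3), dtype=np.uint8)]
    out = predictor(doc)
    assert isinstance(out, Document)
    assert len(out.pages) == 1
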
# Copyright (C) 2021-2022, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
if any(gpu_devices):
    tf.config.experimental.set_memory_growth(gpu_devices[0], True)

from doctr.models import ocr_predictor

predictor = ocr_predictor(pretrained=True)
det_predictor = predictor.det_predictor
reco_predictor = predictor.reco_predictor

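# Hedged sketch of exercising the two sub-predictors independently: detection
# on a full page, recognition on a pre-cropped word. The arrays are synthetic
# placeholders, and the exact output structures depend on the docTR version.
import numpy as np

page = np.zeros((1024, 768, 3), dtype=np.uint8)      # stand-in page image
word_crop = np.zeros((32, 128, 3), dtype=np.uint8)   # stand-in word crop

det_out = det_predictor([page])          # localized text regions per page
reco_out = reco_predictor([word_crop])   # (text, confidence) per crop
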
from doctr.models import ocr_predictor

if __name__ == "__main__":
    doctr_model = ocr_predictor(det_arch="db_resnet50", reco_arch="crnn_vgg16_bn", pretrained=True)
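    # Hedged usage sketch: feeding a document through the predictor built
    # above. "sample.pdf" is a placeholder path, and the DocumentFile import
    # location (doctr.io here) can vary across docTR versions.
    from doctr.io import DocumentFile

    doc = DocumentFile.from_pdf("sample.pdf")
    result = doctr_model(doc)
    print(result.render())  # plain-text rendering of the OCR output
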
def main(args):
    if not args.rotation:
        args.eval_straight = True

    predictor = ocr_predictor(
        args.detection,
        args.recognition,
        pretrained=True,
        reco_bs=args.batch_size,
        assume_straight_pages=not args.rotation,
    )

    if args.img_folder and args.label_file:
        testset = datasets.OCRDataset(
            img_folder=args.img_folder,
            label_file=args.label_file,
        )
        sets = [testset]
    else:
        train_set = datasets.__dict__[args.dataset](train=True, download=True, use_polygons=not args.eval_straight)
        val_set = datasets.__dict__[args.dataset](train=False, download=True, use_polygons=not args.eval_straight)
        sets = [train_set, val_set]

    reco_metric = TextMatch()
    if args.mask_shape:
        det_metric = LocalizationConfusion(
            iou_thresh=args.iou, use_polygons=not args.eval_straight, mask_shape=(args.mask_shape, args.mask_shape)
        )
        e2e_metric = OCRMetric(
            iou_thresh=args.iou, use_polygons=not args.eval_straight, mask_shape=(args.mask_shape, args.mask_shape)
        )
    else:
        det_metric = LocalizationConfusion(iou_thresh=args.iou, use_polygons=not args.eval_straight)
        e2e_metric = OCRMetric(iou_thresh=args.iou, use_polygons=not args.eval_straight)

    sample_idx = 0
    extraction_fn = extract_crops if args.eval_straight else extract_rcrops

    for dataset in sets:
        for page, target in tqdm(dataset):
            # GT
            gt_boxes = target["boxes"]
            gt_labels = target["labels"]

            if args.img_folder and args.label_file:
                x, y, w, h = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2], gt_boxes[:, 3]
                xmin, ymin = np.clip(x - w / 2, 0, 1), np.clip(y - h / 2, 0, 1)
                xmax, ymax = np.clip(x + w / 2, 0, 1), np.clip(y + h / 2, 0, 1)
                gt_boxes = np.stack([xmin, ymin, xmax, ymax], axis=-1)

            # Forward
            if is_tf_available():
                out = predictor(page[None, ...])
                crops = extraction_fn(page, gt_boxes)
                reco_out = predictor.reco_predictor(crops)
            else:
                with torch.no_grad():
                    out = predictor(page[None, ...])
                    # We directly crop on PyTorch tensors, which are in channels_first
                    crops = extraction_fn(page, gt_boxes, channels_last=False)
                    reco_out = predictor.reco_predictor(crops)

            if len(reco_out):
                reco_words, _ = zip(*reco_out)
            else:
                reco_words = []

            # Unpack preds
            pred_boxes = []
            pred_labels = []
            for page in out.pages:
                height, width = page.dimensions
                for block in page.blocks:
                    for line in block.lines:
                        for word in line.words:
                            if not args.rotation:
                                (a, b), (c, d) = word.geometry
                            else:
                                [x1, y1], [x2, y2], [x3, y3], [x4, y4] = word.geometry
                            if gt_boxes.dtype == int:
                                if not args.rotation:
                                    pred_boxes.append(
                                        [int(a * width), int(b * height), int(c * width), int(d * height)]
                                    )
                                else:
                                    if args.eval_straight:
                                        pred_boxes.append(
                                            [
                                                int(width * min(x1, x2, x3, x4)),
                                                int(height * min(y1, y2, y3, y4)),
                                                int(width * max(x1, x2, x3, x4)),
                                                int(height * max(y1, y2, y3, y4)),
                                            ]
                                        )
                                    else:
                                        pred_boxes.append(
                                            [
                                                [int(x1 * width), int(y1 * height)],
                                                [int(x2 * width), int(y2 * height)],
                                                [int(x3 * width), int(y3 * height)],
                                                [int(x4 * width), int(y4 * height)],
                                            ]
                                        )
                            else:
                                if not args.rotation:
                                    pred_boxes.append([a, b, c, d])
                                else:
                                    if args.eval_straight:
                                        pred_boxes.append(
                                            [
                                                min(x1, x2, x3, x4),
                                                min(y1, y2, y3, y4),
                                                max(x1, x2, x3, x4),
                                                max(y1, y2, y3, y4),
                                            ]
                                        )
                                    else:
                                        pred_boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
                            pred_labels.append(word.value)

            # Update the metric
            det_metric.update(gt_boxes, np.asarray(pred_boxes))
            reco_metric.update(gt_labels, reco_words)
            e2e_metric.update(gt_boxes, np.asarray(pred_boxes), gt_labels, pred_labels)

            # Loop break
            sample_idx += 1
            if isinstance(args.samples, int) and args.samples == sample_idx:
                break
        if isinstance(args.samples, int) and args.samples == sample_idx:
            break

    # Unpack aggregated metrics
    print(
        f"Model Evaluation (model= {args.detection} + {args.recognition}, "
        f"dataset={'OCRDataset' if args.img_folder else args.dataset})"
    )
    recall, precision, mean_iou = det_metric.summary()
    print(f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}")
    acc = reco_metric.summary()
    print(f"Text Recognition - Accuracy: {_pct(acc['raw'])} (unicase: {_pct(acc['unicase'])})")
    recall, precision, mean_iou = e2e_metric.summary()
    print(
        f"OCR - Recall: {_pct(recall['raw'])} (unicase: {_pct(recall['unicase'])}), "
        f"Precision: {_pct(precision['raw'])} (unicase: {_pct(precision['unicase'])}), Mean IoU: {_pct(mean_iou)}"
    )

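# _pct is not defined in this excerpt; a hedged sketch of such a helper would
# format a ratio as a percentage and tolerate undefined metrics.
def _pct(val):
    return "N/A" if val is None else f"{val:.2%}"
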
def main():
    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("docTR: Document Text Recognition")
    # For newline
    st.write('\n')
    # Instructions
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
    # Set the columns
    cols = st.columns((1, 1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")
    cols[2].subheader("OCR output")
    cols[3].subheader("Page reconstitution")

    # Sidebar
    # File selection
    st.sidebar.title("Document selection")
    # Disabling warning
    st.set_option('deprecation.showfileUploaderEncoding', False)
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader("Upload files", type=['pdf', 'png', 'jpeg', 'jpg'])
    if uploaded_file is not None:
        if uploaded_file.name.endswith('.pdf'):
            doc = DocumentFile.from_pdf(uploaded_file.read()).as_images()
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
        page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        cols[0].image(doc[page_idx])

    # Model selection
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", DET_ARCHS)
    reco_arch = st.sidebar.selectbox("Text recognition model", RECO_ARCHS)

    # For newline
    st.sidebar.write('\n')

    if st.sidebar.button("Analyze page"):
        if uploaded_file is None:
            st.sidebar.write("Please upload a document")
        else:
            with st.spinner('Loading model...'):
                predictor = ocr_predictor(det_arch, reco_arch, pretrained=True)

            with st.spinner('Analyzing...'):
                # Forward the image to the model
                processed_batches = predictor.det_predictor.pre_processor([doc[page_idx]])
                out = predictor.det_predictor.model(processed_batches[0], return_model_output=True)
                seg_map = out["out_map"]
                seg_map = tf.squeeze(seg_map[0, ...], axis=[2])
                seg_map = cv2.resize(
                    seg_map.numpy(), (doc[page_idx].shape[1], doc[page_idx].shape[0]),
                    interpolation=cv2.INTER_LINEAR,
                )
                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(seg_map)
                ax.axis('off')
                cols[1].pyplot(fig)

                # Plot OCR output
                out = predictor([doc[page_idx]])
                fig = visualize_page(out.pages[0].export(), doc[page_idx], interactive=False)
                cols[2].pyplot(fig)

                # Page reconstitution under input page
                page_export = out.pages[0].export()
                img = out.pages[0].synthesize()
                cols[3].image(img, clamp=True)

                # Display JSON
                st.markdown("\nHere are your analysis results in JSON format:")
                st.json(page_export)
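
# Hedged usage note: a Streamlit app like this is launched from the command
# line (e.g. `streamlit run app.py`, where app.py is a placeholder for the
# file holding main()), so the standard entry-point guard applies.
if __name__ == "__main__":
    main()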