def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
    """Build a detection predictor from an architecture name or an already-instantiated model.

    Args:
        arch: name of a detection architecture (one of ``ARCHS`` or ``ROT_ARCHS``), or a
            ``detection.DBNet`` / ``detection.LinkNet`` instance
        pretrained: whether to load pretrained weights (only used when ``arch`` is a name)
        assume_straight_pages: forwarded to the model constructor when ``arch`` is a name,
            otherwise assigned to the provided model instance
        **kwargs: keyword arguments forwarded to ``PreProcessor``; ``mean`` and ``std`` default to
            the model config values and ``batch_size`` defaults to 1

    Returns:
        a ``DetectionPredictor`` wrapping the preprocessor and the model

    Raises:
        ValueError: if ``arch`` is an unknown name or an unsupported model type
        AssertionError: if an architecture name outside ``ROT_ARCHS`` is requested with
            ``assume_straight_pages=False``
    """
    if isinstance(arch, str):
        if arch not in ARCHS + ROT_ARCHS:
            raise ValueError(f"unknown architecture '{arch}'")

        if arch not in ROT_ARCHS and not assume_straight_pages:
            raise AssertionError(
                "You are trying to use a model trained on straight pages while not assuming"
                " your pages are straight. If you have only straight documents, don't pass"
                " assume_straight_pages=False, otherwise you should use one of these archs:"
                f" {ROT_ARCHS}"
            )

        _model = detection.__dict__[arch](pretrained=pretrained, assume_straight_pages=assume_straight_pages)
    else:
        if not isinstance(arch, (detection.DBNet, detection.LinkNet)):
            raise ValueError(f"unknown architecture: {type(arch)}")

        _model = arch
        _model.assume_straight_pages = assume_straight_pages

    kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
    kwargs["std"] = kwargs.get("std", _model.cfg["std"])
    kwargs["batch_size"] = kwargs.get("batch_size", 1)

    # Drop the channel axis of the configured input shape: it is last for TF, first otherwise
    input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]

    predictor = DetectionPredictor(
        PreProcessor(input_shape, **kwargs),
        _model,
    )
    return predictor
def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task: str) -> None:
    """Save model and config to disk for pushing to huggingface hub

    Weights are written with the PyTorch backend when it is available, otherwise with TensorFlow.
    The model configuration, extended with ``arch`` and ``task`` entries, is written to
    ``<save_dir>/config.json`` as UTF-8 JSON.

    Args:
        model: TF or PyTorch model to be saved; must expose a ``cfg`` mapping
        save_dir: directory to save model and config
        arch: architecture name, stored under the ``arch`` key of the saved config
        task: task name, stored under the ``task`` key of the saved config
    """
    save_directory = Path(save_dir)

    if is_torch_available():
        weights_path = save_directory / "pytorch_model.bin"
        torch.save(model.state_dict(), weights_path)
    elif is_tf_available():
        weights_path = save_directory / "tf_model" / "weights"
        model.save_weights(str(weights_path))

    # Build a new dict instead of writing into ``model.cfg``: the model's cfg may be shared with
    # other objects (e.g. a module-level default config), which must not gain export-only keys
    model_config = {**model.cfg, "arch": arch, "task": task}

    config_path = save_directory / "config.json"
    # ``ensure_ascii=False`` may emit non-ASCII characters, so the encoding must not depend on the
    # platform default
    with config_path.open("w", encoding="utf-8") as f:
        json.dump(model_config, f, indent=2, ensure_ascii=False)
def _predictor(arch: str, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
    """Instantiate a detection model by name and wrap it into a ``DetectionPredictor``.

    Args:
        arch: detection architecture name, must belong to ``ARCHS``
        pretrained: whether the model should load pretrained weights
        assume_straight_pages: forwarded to the model constructor
        **kwargs: extra ``PreProcessor`` options; ``mean``/``std`` fall back to the model config and
            ``batch_size`` falls back to 1

    Returns:
        the assembled ``DetectionPredictor``

    Raises:
        ValueError: when ``arch`` is not a known architecture
        AssertionError: when a straight-only architecture is combined with
            ``assume_straight_pages=False``
    """
    if arch not in ARCHS:
        raise ValueError(f"unknown architecture '{arch}'")

    straight_only = arch not in ROT_ARCHS
    if straight_only and not assume_straight_pages:
        raise AssertionError(
            "You are trying to use a model trained on straight pages while not assuming"
            " your pages are straight. If you have only straight documents, don't pass"
            f" assume_straight_pages=False, otherwise you should use one of these archs: {ROT_ARCHS}"
        )

    # Detection
    model = detection.__dict__[arch](pretrained=pretrained, assume_straight_pages=assume_straight_pages)
    cfg = model.cfg

    # Only fill in the options the caller did not provide
    for option, fallback in (('mean', cfg['mean']), ('std', cfg['std']), ('batch_size', 1)):
        kwargs.setdefault(option, fallback)

    # Spatial size only: strip the trailing channel axis for TF, the leading one otherwise
    if is_tf_available():
        spatial_shape = cfg['input_shape'][:-1]
    else:
        spatial_shape = cfg['input_shape'][1:]

    return DetectionPredictor(PreProcessor(spatial_shape, **kwargs), model)
def _predictor(arch: str, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
    """Instantiate a recognition model by name and wrap it into a ``RecognitionPredictor``.

    Args:
        arch: recognition architecture name, must belong to ``ARCHS``
        pretrained: whether the model should load pretrained weights
        **kwargs: extra ``PreProcessor`` options; ``mean``/``std`` fall back to the model config and
            ``batch_size`` falls back to 32

    Returns:
        the assembled ``RecognitionPredictor`` (aspect ratio preserved during preprocessing)

    Raises:
        ValueError: when ``arch`` is not a known architecture
    """
    if arch not in ARCHS:
        raise ValueError(f"unknown architecture '{arch}'")

    model = recognition.__dict__[arch](pretrained=pretrained)
    cfg = model.cfg

    kwargs.setdefault('mean', cfg['mean'])
    kwargs.setdefault('std', cfg['std'])
    kwargs.setdefault('batch_size', 32)

    # Keep the two spatial dimensions: leading entries for TF configs, trailing ones otherwise
    if is_tf_available():
        input_shape = cfg['input_shape'][:2]
    else:
        input_shape = cfg['input_shape'][-2:]

    preprocessor = PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs)
    return RecognitionPredictor(preprocessor, model)
def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> CropOrientationPredictor:
    """Instantiate a crop orientation classifier and wrap it into a ``CropOrientationPredictor``.

    Args:
        arch: classification architecture name, must belong to ``ORIENTATION_ARCHS``
        pretrained: whether the classifier should load pretrained weights
        **kwargs: extra ``PreProcessor`` options; ``mean``/``std`` fall back to the model config and
            ``batch_size`` falls back to 64

    Returns:
        the assembled ``CropOrientationPredictor`` (aspect-ratio-preserving, symmetric padding)

    Raises:
        ValueError: when ``arch`` is not a known orientation architecture
    """
    if arch in ORIENTATION_ARCHS:
        # Load directly classifier from backbone
        model = classification.__dict__[arch](pretrained=pretrained)
    else:
        raise ValueError(f"unknown architecture '{arch}'")

    config = model.cfg
    for name, fallback in (("mean", config["mean"]), ("std", config["std"]), ("batch_size", 64)):
        kwargs.setdefault(name, fallback)

    # Strip the channel axis: trailing for TF configs, leading otherwise
    full_shape = config["input_shape"]
    input_shape = full_shape[:-1] if is_tf_available() else full_shape[1:]

    preprocessor = PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs)
    return CropOrientationPredictor(preprocessor, model)
def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
    """Build a recognition predictor from an architecture name or an existing model instance.

    Args:
        arch: recognition architecture name (one of ``ARCHS``), or a ``recognition.CRNN``,
            ``recognition.SAR`` or ``recognition.MASTER`` instance
        pretrained: whether to load pretrained weights (only used when ``arch`` is a name)
        **kwargs: extra ``PreProcessor`` options; ``mean``/``std`` fall back to the model config and
            ``batch_size`` falls back to 32

    Returns:
        the assembled ``RecognitionPredictor`` (aspect ratio preserved during preprocessing)

    Raises:
        ValueError: when ``arch`` is an unknown name or an unsupported model type
    """
    if not isinstance(arch, str):
        supported = (recognition.CRNN, recognition.SAR, recognition.MASTER)
        if not isinstance(arch, supported):
            raise ValueError(f"unknown architecture: {type(arch)}")
        model = arch
    elif arch in ARCHS:
        model = recognition.__dict__[arch](pretrained=pretrained)
    else:
        raise ValueError(f"unknown architecture '{arch}'")

    cfg = model.cfg
    # Defaults first, caller-provided options override them
    options = {"mean": cfg["mean"], "std": cfg["std"], "batch_size": 32}
    options.update(kwargs)

    # Keep the two spatial dimensions: leading entries for TF configs, trailing ones otherwise
    if is_tf_available():
        input_shape = cfg["input_shape"][:2]
    else:
        input_shape = cfg["input_shape"][-2:]

    return RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **options), model)
from doctr.file_utils import is_tf_available from .generator import * from .cord import * from .detection import * from .doc_artefacts import * from .funsd import * from .ic03 import * from .ic13 import * from .iiit5k import * from .imgur5k import * from .ocr import * from .recognition import * from .sroie import * from .svhn import * from .svt import * from .synthtext import * from .utils import * from .vocabs import * if is_tf_available(): from .loader import *
def _pred_box_from_geometry(geometry, rotation, eval_straight, absolute, width, height):
    """Convert one predicted word geometry into the box format fed to the metrics.

    Args:
        geometry: relative word geometry, unpacked as two ``(x, y)`` points when ``rotation`` is falsy
            and as exactly four ``(x, y)`` points otherwise
        rotation: whether predictions come from a predictor built with ``assume_straight_pages=False``
        eval_straight: whether rotated predictions are reduced to their enclosing straight box
        absolute: whether coordinates are scaled by the page size and truncated to ``int``
        width: page width used for absolute scaling
        height: page height used for absolute scaling

    Returns:
        ``[xmin, ymin, xmax, ymax]`` for straight boxes, or a list of four ``[x, y]`` points
    """
    if not rotation:
        (xmin, ymin), (xmax, ymax) = geometry
        box = [xmin, ymin, xmax, ymax]
    else:
        (x1, y1), (x2, y2), (x3, y3), (x4, y4) = geometry
        xs, ys = (x1, x2, x3, x4), (y1, y2, y3, y4)
        if not eval_straight:
            if absolute:
                return [[int(x * width), int(y * height)] for x, y in zip(xs, ys)]
            return [[x, y] for x, y in zip(xs, ys)]
        box = [min(xs), min(ys), max(xs), max(ys)]

    if absolute:
        return [int(box[0] * width), int(box[1] * height), int(box[2] * width), int(box[3] * height)]
    return box


def main(args):
    """Evaluate an OCR predictor end to end and print detection, recognition and OCR metrics.

    Args:
        args: parsed command-line arguments; this function reads ``detection``, ``recognition``,
            ``batch_size``, ``rotation``, ``eval_straight``, ``img_folder``, ``label_file``,
            ``dataset``, ``iou``, ``mask_shape`` and ``samples`` (and forces ``eval_straight`` to True
            when ``rotation`` is falsy)
    """
    if not args.rotation:
        args.eval_straight = True

    predictor = ocr_predictor(
        args.detection,
        args.recognition,
        pretrained=True,
        reco_bs=args.batch_size,
        assume_straight_pages=not args.rotation,
    )

    # A custom OCRDataset is only used when both the image folder and the label file are given
    use_custom_dataset = bool(args.img_folder and args.label_file)
    if use_custom_dataset:
        testset = datasets.OCRDataset(
            img_folder=args.img_folder,
            label_file=args.label_file,
        )
        sets = [testset]
    else:
        train_set = datasets.__dict__[args.dataset](train=True, download=True, use_polygons=not args.eval_straight)
        val_set = datasets.__dict__[args.dataset](train=False, download=True, use_polygons=not args.eval_straight)
        sets = [train_set, val_set]

    reco_metric = TextMatch()
    if args.mask_shape:
        det_metric = LocalizationConfusion(
            iou_thresh=args.iou,
            use_polygons=not args.eval_straight,
            mask_shape=(args.mask_shape, args.mask_shape),
        )
        e2e_metric = OCRMetric(
            iou_thresh=args.iou,
            use_polygons=not args.eval_straight,
            mask_shape=(args.mask_shape, args.mask_shape),
        )
    else:
        det_metric = LocalizationConfusion(iou_thresh=args.iou, use_polygons=not args.eval_straight)
        e2e_metric = OCRMetric(iou_thresh=args.iou, use_polygons=not args.eval_straight)

    sample_idx = 0
    extraction_fn = extract_crops if args.eval_straight else extract_rcrops

    for dataset in sets:
        for page, target in tqdm(dataset):
            # GT
            gt_boxes = target['boxes']
            gt_labels = target['labels']

            if use_custom_dataset:
                # Boxes are read as (x_center, y_center, w, h) and converted to clipped
                # (xmin, ymin, xmax, ymax)
                x, y, w, h = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2], gt_boxes[:, 3]
                xmin, ymin = np.clip(x - w / 2, 0, 1), np.clip(y - h / 2, 0, 1)
                xmax, ymax = np.clip(x + w / 2, 0, 1), np.clip(y + h / 2, 0, 1)
                gt_boxes = np.stack([xmin, ymin, xmax, ymax], axis=-1)

            # Forward
            if is_tf_available():
                out = predictor(page[None, ...])
                crops = extraction_fn(page, gt_boxes)
                reco_out = predictor.reco_predictor(crops)
            else:
                with torch.no_grad():
                    out = predictor(page[None, ...])
                    # We directly crop on PyTorch tensors, which are in channels_first
                    crops = extraction_fn(page, gt_boxes, channels_last=False)
                    reco_out = predictor.reco_predictor(crops)

            if len(reco_out):
                reco_words, _ = zip(*reco_out)
            else:
                reco_words = []

            # Unpack preds. Integer GT boxes are treated as absolute pixel coordinates, so predictions
            # are scaled to the page size; any integer dtype qualifies, not only the platform default int
            absolute_coords = np.issubdtype(gt_boxes.dtype, np.integer)
            pred_boxes = []
            pred_labels = []
            # `pred_page` is deliberately distinct from `page`, the input image of this sample
            for pred_page in out.pages:
                height, width = pred_page.dimensions
                for block in pred_page.blocks:
                    for line in block.lines:
                        for word in line.words:
                            pred_boxes.append(
                                _pred_box_from_geometry(
                                    word.geometry,
                                    args.rotation,
                                    args.eval_straight,
                                    absolute_coords,
                                    width,
                                    height,
                                )
                            )
                            pred_labels.append(word.value)

            # Update the metric
            det_metric.update(gt_boxes, np.asarray(pred_boxes))
            reco_metric.update(gt_labels, reco_words)
            e2e_metric.update(gt_boxes, np.asarray(pred_boxes), gt_labels, pred_labels)

            # Loop break
            sample_idx += 1
            if isinstance(args.samples, int) and args.samples == sample_idx:
                break
        if isinstance(args.samples, int) and args.samples == sample_idx:
            break

    # Unpack aggregated metrics
    print(f"Model Evaluation (model= {args.detection} + {args.recognition}, "
          f"dataset={'OCRDataset' if use_custom_dataset else args.dataset})")
    recall, precision, mean_iou = det_metric.summary()
    print(
        f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}"
    )
    acc = reco_metric.summary()
    print(
        f"Text Recognition - Accuracy: {_pct(acc['raw'])} (unicase: {_pct(acc['unicase'])})"
    )
    recall, precision, mean_iou = e2e_metric.summary()
    print(
        f"OCR - Recall: {_pct(recall['raw'])} (unicase: {_pct(recall['unicase'])}), "
        f"Precision: {_pct(precision['raw'])} (unicase: {_pct(precision['unicase'])}), Mean IoU: {_pct(mean_iou)}"
    )
def build_target(
    self,
    target: List[np.ndarray],
    output_shape: Tuple[int, int],
) -> Tuple[np.ndarray, np.ndarray]:
    """Rasterize a batch of relative box targets into a segmentation target and a loss mask.

    Each box is scaled to the ``output_shape`` grid and shrunk inwards with pyclipper by
    ``area * (1 - shrink_ratio ** 2) / perimeter``. The shrunk polygon is then filled with 1 in the
    target. Some boxes are not drawn; instead their enclosing integer rectangle is set to False in the
    mask. This applies to boxes that are too small (``box_size < self.min_size_box``). It also applies
    to boxes whose shrunk polygon is empty, has 2 points or fewer, or is invalid.

    Args:
        target: one float32 array per sample, with relative coordinates in [0, 1]. Either 2D, where
            columns 0/2 hold x coordinates and columns 1/3 hold y coordinates of a straight box, or 3D
            ``(N, P, 2)`` with ``(x, y)`` points per polygon (presumably P == 4 — TODO confirm)
        output_shape: ``(h, w)`` of the segmentation map

    Returns:
        ``(seg_target, seg_mask)``: a uint8 target and a bool mask, both of shape ``(B, h, w, 1)`` with
        ``B = len(target)`` when TensorFlow is available, transposed to ``(B, 1, h, w)`` otherwise

    Raises:
        AssertionError: if any target array is not float32
        ValueError: if any value in ``t[:, :4]`` falls outside [0, 1]
    """

    if any(t.dtype != np.float32 for t in target):
        raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
    # `t[:, :4]` covers the first 4 columns of 2D boxes, or the first 4 points of 3D polygons
    if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for t in target):
        raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")

    h, w = output_shape
    target_shape = (len(target), h, w, 1)

    # Target starts all-background; mask starts all-valid and is switched off where supervision is unreliable
    seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
    seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)

    for idx, _target in enumerate(target):
        # Draw each polygon on gt
        if _target.shape[0] == 0:
            # Empty image, full masked
            # (no `continue`: with zero boxes the loop below simply draws nothing)
            seg_mask[idx] = False

        # Absolute bounding boxes
        abs_boxes = _target.copy()

        if abs_boxes.ndim == 3:
            # Polygons: scale x by w and y by h in place
            abs_boxes[:, :, 0] *= w
            abs_boxes[:, :, 1] *= h
            # `polys` keeps the scaled float points; `abs_boxes` is rebound just below
            polys = abs_boxes
            # Size = distance between points 0 and 2 (the diagonal if points are box corners in order)
            boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1)
            # Enclosing integer rectangle (xmin, ymin, xmax, ymax), used for masking
            abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32)
        else:
            abs_boxes[:, [0, 2]] *= w
            abs_boxes[:, [1, 3]] *= h
            abs_boxes = abs_boxes.round().astype(np.int32)
            # Build the 4 rectangle corners from the straight box
            polys = np.stack(
                [
                    abs_boxes[:, [0, 1]],
                    abs_boxes[:, [0, 3]],
                    abs_boxes[:, [2, 3]],
                    abs_boxes[:, [2, 1]],
                ],
                axis=1,
            )
            # Size = shortest side of the rectangle
            boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])

        for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
            # Mask boxes that are too small
            if box_size < self.min_size_box:
                seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False
                continue

            # Negative shrink for gt, as described in paper
            # Shrink offset D = A * (1 - r^2) / L (A: polygon area, L: perimeter, r: shrink_ratio)
            polygon = Polygon(poly)
            distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length
            subject = [tuple(coor) for coor in poly]
            padding = pyclipper.PyclipperOffset()
            padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
            shrunken = padding.Execute(-distance)

            # Draw polygon on gt if it is valid
            if len(shrunken) == 0:
                seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False
                continue
            # Only the first resulting path is used
            shrunken = np.array(shrunken[0]).reshape(-1, 2)
            if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False
                continue
            cv2.fillPoly(seg_target[idx], [shrunken.astype(np.int32)], 1)

    # Don't forget to switch back to channel first if PyTorch is used
    if not is_tf_available():
        seg_target = seg_target.transpose(0, 3, 1, 2)
        seg_mask = seg_mask.transpose(0, 3, 1, 2)

    return seg_target, seg_mask
def main(det_archs, reco_archs):
    """Build a streamlit layout

    The page shows the uploaded document, the detection heatmap, the OCR visualization, the page
    reconstitution (for non-"rotation" detection archs) and the exported JSON result.

    Args:
        det_archs: text detection architecture names offered in the sidebar
        reco_archs: text recognition architecture names offered in the sidebar
    """
    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("docTR: Document Text Recognition")
    # For newline
    st.write("\n")
    # Instructions
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
    # Set the columns
    cols = st.columns((1, 1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")
    cols[2].subheader("OCR output")
    cols[3].subheader("Page reconstitution")

    # Sidebar
    # File selection
    st.sidebar.title("Document selection")
    # Disabling warning
    st.set_option("deprecation.showfileUploaderEncoding", False)
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader("Upload files", type=["pdf", "png", "jpeg", "jpg"])
    if uploaded_file is not None:
        if uploaded_file.name.endswith(".pdf"):
            doc = DocumentFile.from_pdf(uploaded_file.read())
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
        # The selector is 1-based for display; convert back to a 0-based index
        page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        page = doc[page_idx]
        cols[0].image(page)

    # Model selection
    st.sidebar.title("Model selection")
    st.sidebar.markdown("**Backend**: " + ("TensorFlow" if is_tf_available() else "PyTorch"))
    det_arch = st.sidebar.selectbox("Text detection model", det_archs)
    reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)

    # For newline
    st.sidebar.write("\n")

    if st.sidebar.button("Analyze page"):
        if uploaded_file is None:
            st.sidebar.write("Please upload a document")
        else:
            with st.spinner("Loading model..."):
                predictor = load_predictor(det_arch, reco_arch, forward_device)

            with st.spinner("Analyzing..."):
                # Forward the image to the model
                seg_map = forward_image(predictor, page, forward_device)
                seg_map = np.squeeze(seg_map)
                seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)

                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(seg_map)
                ax.axis("off")
                cols[1].pyplot(fig)
                # pyplot keeps a reference to every figure it creates until it is closed; release it
                # once rendered so repeated analyses do not accumulate figures
                plt.close(fig)

                # Plot OCR output
                out = predictor([page])
                page_export = out.pages[0].export()
                fig = visualize_page(page_export, page, interactive=False)
                cols[2].pyplot(fig)
                plt.close(fig)

                # Page reconstitution under input page
                if "rotation" not in det_arch:
                    img = out.pages[0].synthesize()
                    cols[3].image(img, clamp=True)

                # Display JSON
                st.markdown("\nHere are your analysis results in JSON format:")
                st.json(page_export)
from doctr.file_utils import is_tf_available, is_torch_available if not is_tf_available() and is_torch_available(): from .pytorch import *