Exemplo n.º 1
0
def test_ocrdataset(mock_ocrdataset, use_polygons):
    """Check that ``OCRDataset`` loads the mock data and rejects missing images.

    Args:
        mock_ocrdataset: fixture unpacked as the positional arguments of
            ``datasets.OCRDataset``.
        use_polygons: forwarded to the dataset and to the validation helper.
    """
    input_size = (512, 512)

    ds = datasets.OCRDataset(
        *mock_ocrdataset,
        img_transforms=Resize(input_size),
        use_polygons=use_polygons,
    )
    assert len(ds) == 3
    _validate_dataset(ds, input_size, is_polygons=use_polygons)

    # File existence check: temporarily hide the first image and make sure
    # building the dataset fails.
    img_name, _ = ds.data[0]
    img_path = os.path.join(ds.root, img_name)
    tmp_path = os.path.join(ds.root, "tmp_file")
    move(img_path, tmp_path)
    try:
        with pytest.raises(FileNotFoundError):
            datasets.OCRDataset(*mock_ocrdataset)
    finally:
        # Always put the image back, even when the assertion above fails, so
        # the fixture directory is left intact for whatever uses it next.
        move(tmp_path, img_path)
Exemplo n.º 2
0
def _word_to_box(geometry, width, height, *, rotated, straight, absolute):
    """Convert one predicted word geometry into the layout fed to the metrics.

    Args:
        geometry: ``((xmin, ymin), (xmax, ymax))`` when ``rotated`` is false,
            otherwise four ``(x, y)`` corner points.
        width: page width, used to scale x coordinates when ``absolute``.
        height: page height, used to scale y coordinates when ``absolute``.
        rotated: whether ``geometry`` holds four corner points.
        straight: for rotated geometries, return the enclosing axis-aligned
            box instead of the four corners.
        absolute: scale the coordinates by the page size and truncate to int.

    Returns:
        ``[xmin, ymin, xmax, ymax]`` for straight output, or a list of four
        ``[x, y]`` points for polygon output.
    """
    if not rotated:
        (xmin, ymin), (xmax, ymax) = geometry
        if absolute:
            return [
                int(xmin * width),
                int(ymin * height),
                int(xmax * width),
                int(ymax * height),
            ]
        return [xmin, ymin, xmax, ymax]

    (x1, y1), (x2, y2), (x3, y3), (x4, y4) = geometry
    if straight:
        xmin, xmax = min(x1, x2, x3, x4), max(x1, x2, x3, x4)
        ymin, ymax = min(y1, y2, y3, y4), max(y1, y2, y3, y4)
        if absolute:
            return [
                int(width * xmin),
                int(height * ymin),
                int(width * xmax),
                int(height * ymax),
            ]
        return [xmin, ymin, xmax, ymax]

    corners = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
    if absolute:
        return [[int(x * width), int(y * height)] for x, y in corners]
    return corners


def main(args):
    """Evaluate an OCR predictor on a dataset and print the metrics.

    Detection, recognition (on ground-truth crops) and end-to-end metrics are
    accumulated over every sample, then printed.

    Args:
        args: parsed options. The attributes read here are ``detection``,
            ``recognition``, ``batch_size``, ``rotation``, ``eval_straight``,
            ``img_folder``, ``label_file``, ``dataset``, ``mask_shape``,
            ``iou`` and ``samples``. ``eval_straight`` is forced to ``True``
            on this object when ``rotation`` is falsy.
    """
    if not args.rotation:
        args.eval_straight = True

    predictor = ocr_predictor(args.detection,
                              args.recognition,
                              pretrained=True,
                              reco_bs=args.batch_size,
                              assume_straight_pages=not args.rotation)

    # A custom dataset is only used when both the folder and the labels are
    # given; the same flag drives box conversion and the summary label.
    use_custom_dataset = bool(args.img_folder and args.label_file)

    if use_custom_dataset:
        sets = [
            datasets.OCRDataset(
                img_folder=args.img_folder,
                label_file=args.label_file,
            )
        ]
    else:
        dataset_cls = datasets.__dict__[args.dataset]
        sets = [
            dataset_cls(train=is_train,
                        download=True,
                        use_polygons=not args.eval_straight)
            for is_train in (True, False)
        ]

    # Both localization-based metrics share the same configuration.
    metric_kwargs = {
        "iou_thresh": args.iou,
        "use_polygons": not args.eval_straight,
    }
    if args.mask_shape:
        metric_kwargs["mask_shape"] = (args.mask_shape, args.mask_shape)

    reco_metric = TextMatch()
    det_metric = LocalizationConfusion(**metric_kwargs)
    e2e_metric = OCRMetric(**metric_kwargs)

    extraction_fn = extract_crops if args.eval_straight else extract_rcrops

    # None never equals an int counter, so a non-integer limit disables the
    # early stop.
    max_samples = args.samples if isinstance(args.samples, int) else None
    sample_idx = 0

    for dataset in sets:
        for page, target in tqdm(dataset):
            # GT
            gt_boxes = target['boxes']
            gt_labels = target['labels']

            if use_custom_dataset:
                # Columns are treated as (x, y, w, h) with (x, y) the box
                # centre; convert to corners clipped to [0, 1].
                x, y = gt_boxes[:, 0], gt_boxes[:, 1]
                w, h = gt_boxes[:, 2], gt_boxes[:, 3]
                xmin, ymin = np.clip(x - w / 2, 0, 1), np.clip(y - h / 2, 0, 1)
                xmax, ymax = np.clip(x + w / 2, 0, 1), np.clip(y + h / 2, 0, 1)
                gt_boxes = np.stack([xmin, ymin, xmax, ymax], axis=-1)

            # Forward
            if is_tf_available():
                out = predictor(page[None, ...])
                crops = extraction_fn(page, gt_boxes)
                reco_out = predictor.reco_predictor(crops)
            else:
                with torch.no_grad():
                    out = predictor(page[None, ...])
                    # We directly crop on PyTorch tensors, which are in channels_first
                    crops = extraction_fn(page, gt_boxes, channels_last=False)
                    reco_out = predictor.reco_predictor(crops)

            if len(reco_out):
                reco_words, _ = zip(*reco_out)
            else:
                reco_words = []

            # Unpack preds. Predictions follow the ground truth's convention:
            # integer ground-truth boxes mean absolute pixel coordinates.
            absolute_coords = gt_boxes.dtype == int
            pred_boxes = []
            pred_labels = []
            for pred_page in out.pages:
                height, width = pred_page.dimensions
                for block in pred_page.blocks:
                    for line in block.lines:
                        for word in line.words:
                            pred_boxes.append(
                                _word_to_box(
                                    word.geometry,
                                    width,
                                    height,
                                    rotated=bool(args.rotation),
                                    straight=bool(args.eval_straight),
                                    absolute=absolute_coords,
                                ))
                            pred_labels.append(word.value)

            # Update the metric
            pred_boxes = np.asarray(pred_boxes)
            det_metric.update(gt_boxes, pred_boxes)
            reco_metric.update(gt_labels, reco_words)
            e2e_metric.update(gt_boxes, pred_boxes, gt_labels, pred_labels)

            # Loop break
            sample_idx += 1
            if sample_idx == max_samples:
                break
        if sample_idx == max_samples:
            break

    # Unpack aggregated metrics
    print(f"Model Evaluation (model= {args.detection} + {args.recognition}, "
          f"dataset={'OCRDataset' if use_custom_dataset else args.dataset})")
    recall, precision, mean_iou = det_metric.summary()
    print(
        f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}"
    )
    acc = reco_metric.summary()
    print(
        f"Text Recognition - Accuracy: {_pct(acc['raw'])} (unicase: {_pct(acc['unicase'])})"
    )
    recall, precision, mean_iou = e2e_metric.summary()
    print(
        f"OCR - Recall: {_pct(recall['raw'])} (unicase: {_pct(recall['unicase'])}), "
        f"Precision: {_pct(precision['raw'])} (unicase: {_pct(precision['unicase'])}), Mean IoU: {_pct(mean_iou)}"
    )