Example #1
def compute_error(model: TorchSeqRecognizer, validation_set: Sequence[Tuple[str, str]]) -> Tuple[int, int]:
    """
    Computes an error report from a model and a list of line image-text pairs.

    Args:
        model (kraken.lib.models.TorchSeqRecognizer): Model used for recognition
        validation_set (list): List of tuples (image, text) for validation

    Returns:
        A tuple with total number of characters and edit distance across the
        whole validation set.
    """
    total_chars = 0
    error = 0
    for im, text in validation_set:
        pred = model.predict_string(im)
        total_chars += len(text)
        error += _fast_levenshtein(pred, text)
    return total_chars, error
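A minimal usage sketch for this variant, assuming a model loaded via kraken.lib.models.load_any; the model path and the validation pairs are illustrative, and the first element of each pair is simply whatever this variant's predict_string accepts:

# Usage sketch; the file names and reference strings below are assumptions.
from kraken.lib import models

model = models.load_any('model.mlmodel')
validation_set = [
    ('line_0001.png', 'reference transcription one'),
    ('line_0002.png', 'reference transcription two'),
]
total_chars, error = compute_error(model, validation_set)
# Character accuracy is one minus the normalized edit distance.
print('accuracy: {:.4f}'.format(1 - error / total_chars))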
Example #2
File: dataset.py Project: maxnth/kraken
def compute_error(model: TorchSeqRecognizer, validation_set: Iterable[Dict[str, torch.Tensor]]) -> Tuple[int, int]:
    """
    Computes an error report from a model and an iterable of validation batches.

    Args:
        model (kraken.lib.models.TorchSeqRecognizer): Model used for recognition
        validation_set (Iterable): Iterable of batch dictionaries with
                                   'image', 'seq_lens', 'target', and
                                   'target_lens' entries

    Returns:
        A tuple with total number of characters and edit distance across the
        whole validation set.
    """
    total_chars = 0
    error = 0
    for batch in validation_set:
        preds = model.predict_string(batch['image'], batch['seq_lens'])
        total_chars += int(batch['target_lens'].sum())
        for pred, text in zip(preds, batch['target']):
            error += _fast_levenshtein(pred, text)
    return total_chars, error
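The batched variant consumes dataloader-style dictionaries rather than bare pairs. A rough sketch of the batch layout implied by the keys read above; the shapes, file name, and values are invented for illustration:

import torch
from kraken.lib import models

model = models.load_any('model.mlmodel')       # path is an assumption
batch = {
    'image': torch.zeros(2, 1, 48, 400),       # N x C x H x W line images
    'seq_lens': torch.tensor([400, 360]),      # valid width of each line
    'target': ['first reference', 'second reference'],
    'target_lens': torch.tensor([15, 16]),     # reference lengths in characters
}
total_chars, error = compute_error(model, [batch])
print(total_chars, error)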
Example #3
def rpred(network: TorchSeqRecognizer,
          im: Image.Image,
          bounds: dict,
          pad: int = 16,
          bidi_reordering: bool = True) -> Generator[ocr_record, None, None]:
    """
    Uses an RNN to recognize text.

    Args:
        network (kraken.lib.models.TorchSeqRecognizer): A TorchSeqRecognizer
                                                        object
        im (PIL.Image.Image): Image to extract text from
        bounds (dict): A dictionary containing a 'boxes' entry with a list of
                       coordinates (x0, y0, x1, y1) of a text line in the image
                       and an entry 'text_direction' containing
                       'horizontal-lr/rl/vertical-lr/rl'.
        pad (int): Extra blank padding to the left and right of the text line.
                   Auto-disabled when expected network inputs are incompatible
                   with padding.
        bidi_reordering (bool): Reorder classes in the ocr_record according to
                                the Unicode bidirectional algorithm for correct
                                display.
    Yields:
        An ocr_record containing the recognized text, absolute character
        positions, and confidence values for each character.
    """
    im_str = get_im_str(im)
    logger.info('Running recognizer on {} with {} lines'.format(
        im_str, len(bounds['boxes'])))
    logger.debug('Loading line transform')
    batch, channels, height, width = network.nn.input
    ts = generate_input_transforms(batch, height, width, channels, pad)

    for box, coords in extract_boxes(im, bounds):
        # check if boxes are non-zero in any dimension
        if sum(coords[::2]) == 0 or coords[3] - coords[1] == 0:
            logger.warning(
                'bbox {} with zero dimension. Emitting empty record.'.format(
                    coords))
            yield ocr_record('', [], [])
            continue
        # try conversion into tensor
        try:
            line = ts(box)
        except Exception:
            yield ocr_record('', [], [])
            continue
        # check if line is non-zero
        if line.max() == line.min():
            yield ocr_record('', [], [])
            continue

        preds = network.predict(line)
        # calculate recognized LSTM locations of characters
        # scale between network output and network input
        net_scale = line.shape[2] / network.outputs.shape[1]
        # scale between network input and original line
        in_scale = box.size[0] / (line.shape[2] - 2 * pad)

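        # Map a position in the network output back to a pixel offset in the
        # original line: rescale to network-input pixels, subtract the padding,
        # rescale to the unpadded original line, and clamp to [min_val, max_val].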
        def _scale_val(val, min_val, max_val):
            return int(
                round(
                    min(max(((val * net_scale) - pad) * in_scale, min_val),
                        max_val)))

        # XXX: fix bounding box calculation in ocr_record for multi-codepoint labels.
        pred = ''.join(x[0] for x in preds)
        pos = []
        conf = []
        for _, start, end, c in preds:
            if bounds['text_direction'].startswith('horizontal'):
                xmin = coords[0] + _scale_val(start, 0, box.size[0])
                xmax = coords[0] + _scale_val(end, 0, box.size[0])
                pos.append((xmin, coords[1], xmax, coords[3]))
            else:
                ymin = coords[1] + _scale_val(start, 0, box.size[1])
                ymax = coords[1] + _scale_val(end, 0, box.size[1])
                pos.append((coords[0], ymin, coords[2], ymax))
            conf.append(c)
        if bidi_reordering:
            logger.debug('BiDi reordering record.')
            yield bidi_record(ocr_record(pred, pos, conf))
        else:
            logger.debug('Emitting raw record')
            yield ocr_record(pred, pos, conf)
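Since rpred is a generator, callers iterate over the yielded records. A short sketch under assumed inputs; the model and image paths and the single line box are illustrative, while the record fields follow the ocr_record(pred, pos, conf) construction above:

from PIL import Image
from kraken.lib import models

net = models.load_any('model.mlmodel')   # model path is an assumption
im = Image.open('page.png')              # image path is an assumption
bounds = {'boxes': [(10, 20, 600, 60)],  # one (x0, y0, x1, y1) line box
          'text_direction': 'horizontal-lr'}

for record in rpred(net, im, bounds):
    # each record carries the recognized text with per-character positions
    # and confidences
    print(record.prediction)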