import logging

from typing import Dict, Generator, Iterable, Sequence, Tuple

import torch
from PIL import Image

from kraken.lib.models import TorchSeqRecognizer
from kraken.lib.dataset import generate_input_transforms
from kraken.lib.util import get_im_str

# ocr_record, bidi_record, and extract_boxes are defined elsewhere in this
# module and are used as-is below.
logger = logging.getLogger(__name__)


def compute_error(model: TorchSeqRecognizer,
                  validation_set: Sequence[Tuple[str, str]]) -> Tuple[int, int]:
    """
    Computes error report from a model and a list of line image-text pairs.

    Args:
        model (kraken.lib.models.TorchSeqRecognizer): Model used for recognition
        validation_set (list): List of tuples (image, text) for validation

    Returns:
        A tuple with total number of characters and edit distance across the
        whole validation set.
    """
    total_chars = 0
    error = 0
    for im, text in validation_set:
        pred = model.predict_string(im)
        total_chars += len(text)
        error += _fast_levenshtein(pred, text)
    return total_chars, error
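
# A minimal sketch of the _fast_levenshtein helper used above, assuming plain
# unit-cost edit distance over insertions, deletions, and substitutions; that
# is all compute_error needs. It keeps only two rows of the DP table, with the
# trailing cell of each row doubling as the left-boundary value.
def _fast_levenshtein(seq1: str, seq2: str) -> int:
    oneago = None
    # row for the empty prefix of seq1; thisrow[-1] = 0 is the corner cell
    thisrow = list(range(1, len(seq2) + 1)) + [0]
    for x in range(len(seq1)):
        oneago, thisrow = thisrow, [0] * len(seq2) + [x + 1]
        for y in range(len(seq2)):
            delcost = oneago[y] + 1
            addcost = thisrow[y - 1] + 1
            subcost = oneago[y - 1] + (seq1[x] != seq2[y])
            thisrow[y] = min(delcost, addcost, subcost)
    return thisrow[len(seq2) - 1]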
def compute_error(model: TorchSeqRecognizer,
                  validation_set: Iterable[Dict[str, torch.Tensor]]) -> Tuple[int, int]:
    """
    Computes error report from a model and an iterable of validation batches.

    Args:
        model (kraken.lib.models.TorchSeqRecognizer): Model used for recognition
        validation_set (Iterable): Iterable of batch dictionaries with 'image',
                                   'seq_lens', 'target', and 'target_lens'
                                   entries

    Returns:
        A tuple with total number of characters and edit distance across the
        whole validation set.
    """
    total_chars = 0
    error = 0
    for batch in validation_set:
        preds = model.predict_string(batch['image'], batch['seq_lens'])
        # target_lens is a tensor; cast the sum so the return values stay ints
        total_chars += int(batch['target_lens'].sum())
        for pred, text in zip(preds, batch['target']):
            error += _fast_levenshtein(pred, text)
    return total_chars, error
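
# Hypothetical usage sketch for the batched compute_error above (not part of
# kraken): `model_path` and `val_loader` are placeholders, the latter standing
# for any iterable yielding dicts with 'image', 'seq_lens', 'target', and
# 'target_lens' entries, e.g. a DataLoader over a ground truth dataset.
def _example_validation(model_path, val_loader):
    from kraken.lib import models
    model = models.load_any(model_path)
    chars, error = compute_error(model, val_loader)
    # character accuracy: 1.0 means every codepoint was recognized correctly
    return (chars - error) / chars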
def rpred(network: TorchSeqRecognizer,
          im: Image.Image,
          bounds: dict,
          pad: int = 16,
          bidi_reordering: bool = True) -> Generator[ocr_record, None, None]:
    """
    Uses an RNN to recognize text

    Args:
        network (kraken.lib.models.TorchSeqRecognizer): A TorchSeqRecognizer
                                                        object
        im (PIL.Image.Image): Image to extract text from
        bounds (dict): A dictionary containing a 'boxes' entry with a list of
                       coordinates (x0, y0, x1, y1) of a text line in the image
                       and an entry 'text_direction' containing
                       'horizontal-lr/rl/vertical-lr/rl'.
        pad (int): Extra blank padding to the left and right of text line.
                   Auto-disabled when expected network inputs are incompatible
                   with padding.
        bidi_reordering (bool): Reorder classes in the ocr_record according to
                                the Unicode bidirectional algorithm for correct
                                display.

    Yields:
        An ocr_record containing the recognized text, absolute character
        positions, and confidence values for each character.
    """
    im_str = get_im_str(im)
    logger.info('Running recognizer on {} with {} lines'.format(im_str, len(bounds['boxes'])))
    logger.debug('Loading line transform')
    batch, channels, height, width = network.nn.input
    ts = generate_input_transforms(batch, height, width, channels, pad)

    for box, coords in extract_boxes(im, bounds):
        # check if boxes are non-zero in any dimension
        if sum(coords[::2]) == 0 or coords[3] - coords[1] == 0:
            logger.warning('bbox {} with zero dimension. Emitting empty record.'.format(coords))
            yield ocr_record('', [], [])
            continue
        # try conversion into tensor
        try:
            line = ts(box)
        except Exception:
            yield ocr_record('', [], [])
            continue
        # check if line is non-zero
        if line.max() == line.min():
            yield ocr_record('', [], [])
            continue
        preds = network.predict(line)
        # calculate recognized LSTM locations of characters
        # scale between network output and network input
        net_scale = line.shape[2] / network.outputs.shape[1]
        # scale between network input and original line
        in_scale = box.size[0] / (line.shape[2] - 2 * pad)

        def _scale_val(val, min_val, max_val):
            return int(round(min(max(((val * net_scale) - pad) * in_scale, min_val), max_val)))

        # XXX: fix bounding box calculation ocr_record for multi-codepoint labels.
        pred = ''.join(x[0] for x in preds)
        pos = []
        conf = []
        for _, start, end, c in preds:
            if bounds['text_direction'].startswith('horizontal'):
                xmin = coords[0] + _scale_val(start, 0, box.size[0])
                xmax = coords[0] + _scale_val(end, 0, box.size[0])
                pos.append((xmin, coords[1], xmax, coords[3]))
            else:
                ymin = coords[1] + _scale_val(start, 0, box.size[1])
                ymax = coords[1] + _scale_val(end, 0, box.size[1])
                pos.append((coords[0], ymin, coords[2], ymax))
            conf.append(c)
        if bidi_reordering:
            logger.debug('BiDi reordering record.')
            yield bidi_record(ocr_record(pred, pos, conf))
        else:
            logger.debug('Emitting raw record')
            yield ocr_record(pred, pos, conf)
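
# Hypothetical end-to-end sketch for rpred (not part of kraken; names and
# paths are placeholders): load a model with kraken's model loader, open a
# page image, describe the line geometry, and drain the generator into a list
# of recognized strings.
def _example_recognition(model_path, image_path, boxes):
    from kraken.lib import models
    net = models.load_any(model_path)
    im = Image.open(image_path)
    # one (x0, y0, x1, y1) box per text line, read left-to-right
    bounds = {'boxes': boxes, 'text_direction': 'horizontal-lr'}
    # each ocr_record carries .prediction, .cuts, and .confidences
    return [record.prediction for record in rpred(net, im, bounds)]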