def load_any(fname: str,
             train: bool = False,
             device: str = 'cpu') -> TorchSeqRecognizer:
    """
    Loads a model file with the first parser that accepts it and wraps the
    result in a TorchSeqRecognizer.

    The path is normalized first (user home and environment variables are
    expanded, then it is made absolute). The parsers are then attempted in a
    fixed order, and the first one that does not raise is used:

        * vgsl: protobuf models containing VGSL segmentation and recognition
          networks (TorchVGSLModel.load_model)
        * clstm: protobuf models containing CLSTM networks, recognition only
          (TorchVGSLModel.load_clstm_model)
        * pronn: protobuf models containing converted python BIDILSTMs,
          recognition only (TorchVGSLModel.load_pronn_model)

    The label of the parser that succeeded ('vgsl', 'clstm' or 'pronn') is
    stored in the `kind` attribute of the returned recognizer.

    Args:
        fname: Path to the model
        train: Forwarded to TorchSeqRecognizer; enables gradient calculation
               and dropout layers in model.
        device: Target device, forwarded to TorchSeqRecognizer.

    Returns:
        A kraken.lib.models.TorchSeqRecognizer object.

    Raises:
        KrakenInvalidModelException: if the model is not loadable by any
                                     parser.
    """
    fname = abspath(expandvars(expanduser(fname)))
    logger.info('Loading model from {}'.format(fname))

    # Each parser is wrapped in a lambda so that both the attribute lookup on
    # TorchVGSLModel and the call itself happen inside the try block below.
    candidates = (
        ('vgsl', lambda: TorchVGSLModel.load_model(str(fname))),
        ('clstm', lambda: TorchVGSLModel.load_clstm_model(fname)),
        ('pronn', lambda: TorchVGSLModel.load_pronn_model(fname)),
    )

    nn = None
    kind = ''
    for label, parse in candidates:
        try:
            nn = parse()
        except Exception:
            # This parser rejected the file; fall through to the next one.
            continue
        kind = label
        break

    if not nn:
        raise KrakenInvalidModelException(
            'File {} not loadable by any parser.'.format(fname))

    seq = TorchSeqRecognizer(nn, train=train, device=device)
    seq.kind = kind
    return seq
def load_any(fname: str, train: bool = False, device: str = 'cpu') -> TorchSeqRecognizer:
    """
    Loads anything that was, is, and will be a valid ocropus model and
    instantiates a TorchSeqRecognizer from the network stored in the file.

    The parsers below are tried in order and the first one that loads the
    file without raising wins; later parsers are not attempted after a
    success, so an already loaded model is never overwritten.

    Additionally an attribute 'kind' will be added to the recognizer
    containing a string representation of the source kind. Current known
    values are:

        * vgsl for models loaded by TorchVGSLModel.load_model
        * clstm for protobuf models generated by clstm
        * pronn for protobuf models containing converted python BIDILSTMs
        * pyrnn for pickled BIDILSTMs

    Args:
        fname (str): Path to the model
        train (bool): Enables gradient calculation and dropout layers in model.
        device (str): Target device

    Returns:
        A kraken.lib.models.TorchSeqRecognizer object.

    Raises:
        KrakenInvalidModelException: if the model is not loadable by any
                                     parser.
    """
    fname = abspath(expandvars(expanduser(fname)))
    logger.info('Loading model from {}'.format(fname))

    # Parsers in priority order. Lambdas defer both the attribute lookup and
    # the call so that every failure mode of a parser is handled uniformly.
    # NOTE(review): pyrnn models are described as pickled; presumably loading
    # one unpickles data, so only trusted files should be given - confirm.
    parsers = (
        ('vgsl', lambda: TorchVGSLModel.load_model(str(fname))),
        ('clstm', lambda: TorchVGSLModel.load_clstm_model(fname)),
        ('pronn', lambda: TorchVGSLModel.load_pronn_model(fname)),
        ('pyrnn', lambda: TorchVGSLModel.load_pyrnn_model(fname)),
    )

    nn = None
    kind = ''
    for name, parse in parsers:
        try:
            nn = parse()
        except Exception as exc:
            # Best effort: a failing parser only means "not this format".
            logger.debug('Parser %s could not load %s: %s', name, fname, exc)
            continue
        kind = name
        break

    if not nn:
        raise KrakenInvalidModelException('File {} not loadable by any parser.'.format(fname))

    seq = TorchSeqRecognizer(nn, train=train, device=device)
    seq.kind = kind
    return seq
def segment(ctx, model, boxes, text_direction, scale, maxcolseps,
            black_colseps, remove_hlines, pad, mask):
    """
    Segments page images into text lines.

    Decides between the legacy box segmenter and the baseline (neural)
    segmenter, loads the baseline model when it is needed, and returns a
    `partial` of `segmenter` with all options and the target device bound.

    If a model is given together with the legacy segmenter flag, a warning is
    logged and the baseline segmenter is used instead. If the baseline
    segmenter is selected without a model, SEGMENTATION_DEFAULT_MODEL is
    loaded. When loading the model or moving it to the device fails, a red
    cross is printed and `ctx.exit(1)` is called.
    """
    # An explicit model takes precedence over the legacy segmenter flag.
    if model and boxes:
        logger.warning(
            f'Baseline model ({model}) given but legacy segmenter selected. '
            'Forcing to -bl.'
        )
        boxes = False

    # Kept as an equality test rather than truthiness: it matches False (and
    # 0) but not None, so an unset flag does not trigger model loading.
    baseline_selected = boxes == False  # noqa: E712
    if baseline_selected:
        model = model or SEGMENTATION_DEFAULT_MODEL

        # Imported lazily, only when a network actually has to be loaded.
        from kraken.lib.vgsl import TorchVGSLModel

        message(f'Loading ANN {model}\t', nl=False)
        try:
            model = TorchVGSLModel.load_model(model)
            model.to(ctx.meta['device'])
        except Exception:
            message('\u2717', fg='red')
            ctx.exit(1)
        message('\u2713', fg='green')

    return partial(segmenter,
                   boxes,
                   model,
                   text_direction,
                   scale,
                   maxcolseps,
                   black_colseps,
                   remove_hlines,
                   pad,
                   mask,
                   ctx.meta['device'])
def compute_segmentation_map(im: PIL.Image.Image,
                             mask: Optional[PIL.Image.Image] = None,
                             model: vgsl.TorchVGSLModel = None,
                             device: str = 'cpu') -> Dict[str, Any]:
    """
    Runs a segmentation network on an image and returns its upsampled output
    together with the metadata needed to interpret it.

    Args:
        im: Input image
        mask: A bi-level mask image of the same size as `im` where 0-valued
              regions are ignored for segmentation purposes. It is accessed
              through the PIL image interface (`mode`, `convert`, `size`).
        model: A TorchVGSLModel containing a segmentation model. Required in
               practice despite the `None` default: its attributes are read
               unconditionally.
        device: The target device to run the neural network on.

    Returns:
        A dictionary containing the heatmaps ('heatmap', np.ndarray), class
        map ('cls_map', taken from the model's 'class_mapping' metadata), the
        bounding regions for polygonization purposes ('bounding_regions',
        taken from the model metadata or None if absent), the scale between
        the input image and the network output ('scale', np.ndarray), and the
        grayscale, network-sized input image ('scal_im', np.ndarray).

    Raises:
        KrakenInputException: When given an invalid mask (not bitonal or of a
                              different size than `im`).
    """
    im_str = get_im_str(im)
    logger.info(f'Segmenting {im_str}')

    if model.input[1] == 1 and model.one_channel_mode == '1' and not is_bitonal(im):
        logger.warning('Running binary model on non-binary input image '
                       '(mode {}). This will result in severely degraded '
                       'performance'.format(im.mode))

    model.eval()
    model.to(device)

    batch, channels, height, width = model.input
    transforms = dataset.ImageInputTransforms(batch, height, width, channels, 0, valid_norm=False)

    # Apply only the transforms preceding ToTensor to obtain the image as the
    # network sees it, but still as an image rather than a tensor.
    tf_idx = next(idx for idx, t in enumerate(transforms.transforms)
                  if isinstance(t, tf.ToTensor))
    res_tf = tf.Compose(transforms.transforms[:tf_idx])
    scal_im = np.array(res_tf(im).convert('L'))

    tensor_im = transforms(im)

    # Explicit None check: truthiness is the wrong test for an optional
    # image/array argument (ambiguous for arrays).
    if mask is not None:
        if mask.mode != '1' and not is_bitonal(mask):
            logger.error('Mask is not bitonal')
            raise KrakenInputException('Mask is not bitonal')
        mask = mask.convert('1')
        if mask.size != im.size:
            # Interpolate the actual sizes into the message.
            msg = f"Mask size {mask.size} doesn't match image size {im.size}"
            logger.error(msg)
            raise KrakenInputException(msg)
        logger.info('Masking enabled in segmenter.')
        # Zero out every input position where the transformed mask is 0.
        tensor_im[~transforms(mask).bool()] = 0

    with torch.no_grad():
        logger.debug('Running network forward pass')
        o, _ = model.nn(tensor_im.unsqueeze(0).to(device))

    logger.debug('Upsampling network output')
    o = F.interpolate(o, size=scal_im.shape)
    o = o.squeeze().cpu().numpy()

    # All output dims except the first, reversed, against the PIL size.
    # Assumes `o` is (classes, height, width) after squeeze so that this
    # lines up with PIL's (width, height) order - TODO confirm.
    scale = np.divide(im.size, o.shape[:0:-1])

    bounding_regions = model.user_metadata.get('bounding_regions')

    return {
        'heatmap': o,
        'cls_map': model.user_metadata['class_mapping'],
        'bounding_regions': bounding_regions,
        'scale': scale,
        'scal_im': scal_im,
    }