Ejemplo n.º 1
0
def load_any(fname: str,
             train: bool = False,
             device: str = 'cpu') -> TorchSeqRecognizer:
    """
    Loads a model file with the first parser that accepts it and wraps the
    result in a new TorchSeqRecognizer.

    The parsers are tried in this order, stopping at the first one that does
    not raise:

        * TorchVGSLModel.load_model (VGSL segmentation and recognition
          networks)
        * TorchVGSLModel.load_clstm_model (protobuf models containing CLSTM
          networks, recognition only)
        * TorchVGSLModel.load_pronn_model (presumably protobuf models
          containing converted python BIDILSTMs, recognition only)

    Additionally an attribute 'kind' is set on the returned recognizer
    containing a string naming the parser that succeeded. Current known
    values are:

        * vgsl for models read by load_model
        * clstm for models read by load_clstm_model
        * pronn for models read by load_pronn_model

    Args:
        fname: Path to the model. User home and environment variables are
               expanded and the path is made absolute before loading.
        train: Passed through to TorchSeqRecognizer; enables gradient
               calculation and dropout layers in model.
        device: Target device, passed through to TorchSeqRecognizer.

    Returns:
        A TorchSeqRecognizer object wrapping the loaded network.

    Raises:
        KrakenInvalidModelException: if the model is not loadable by any parser.
    """
    nn = None
    kind = ''
    fname = abspath(expandvars(expanduser(fname)))
    logger.info('Loading model from {}'.format(fname))

    # Ordered fallback chain. The loaders are wrapped in lambdas so that the
    # attribute lookup on TorchVGSLModel happens inside the try block below,
    # exactly like the call itself.
    parsers = (
        ('vgsl', lambda: TorchVGSLModel.load_model(str(fname))),
        ('clstm', lambda: TorchVGSLModel.load_clstm_model(fname)),
        ('pronn', lambda: TorchVGSLModel.load_pronn_model(fname)),
    )
    for parser_kind, loader in parsers:
        try:
            nn = loader()
        except Exception as exc:
            # A failing parser is expected for files of another format; keep
            # the reason available for debugging instead of discarding it.
            logger.debug('Parser %s could not load %s: %s',
                         parser_kind, fname, exc)
            continue
        kind = parser_kind
        break

    if nn is None:
        raise KrakenInvalidModelException(
            'File {} not loadable by any parser.'.format(fname))
    seq = TorchSeqRecognizer(nn, train=train, device=device)
    seq.kind = kind
    return seq
Ejemplo n.º 2
0
def load_any(fname: str, train: bool = False, device: str = 'cpu') -> TorchSeqRecognizer:
    """
    Loads a model file with the first parser that accepts it and wraps the
    result in a new TorchSeqRecognizer.

    The parsers are tried in this order, stopping at the first one that does
    not raise:

        * TorchVGSLModel.load_model
        * TorchVGSLModel.load_clstm_model (protobuf models containing CLSTM
          networks)
        * TorchVGSLModel.load_pronn_model (presumably protobuf models
          containing converted python BIDILSTMs)
        * TorchVGSLModel.load_pyrnn_model (pyrnn models containing BIDILSTMs)

    Additionally an attribute 'kind' is set on the returned recognizer
    containing a string naming the parser that succeeded. Current known
    values are:

        * vgsl for models read by load_model
        * clstm for protobuf models generated by clstm
        * pronn for models read by load_pronn_model
        * pyrnn for pickled BIDILSTMs

    Args:
        fname (str): Path to the model. User home and environment variables
                     are expanded and the path is made absolute before loading.
        train (bool): Enables gradient calculation and dropout layers in model.
        device (str): Target device

    Returns:
        A TorchSeqRecognizer object wrapping the loaded network.

    Raises:
        KrakenInvalidModelException: if the model is not loadable by any parser.
    """
    nn = None
    kind = ''
    fname = abspath(expandvars(expanduser(fname)))
    logger.info(u'Loading model from {}'.format(fname))

    # Ordered fallback chain: each parser only runs when every earlier one
    # raised, so a successfully loaded model is never overwritten by a later
    # parser and no parser failure escapes as a raw exception. The loaders are
    # wrapped in lambdas so the attribute lookup happens inside the try block.
    # NOTE(review): the pyrnn format is described as pickled; unpickling files
    # from untrusted sources is unsafe -- confirm how load_pyrnn_model reads
    # the file before exposing this to external input.
    parsers = (
        ('vgsl', lambda: TorchVGSLModel.load_model(str(fname))),
        ('clstm', lambda: TorchVGSLModel.load_clstm_model(fname)),
        ('pronn', lambda: TorchVGSLModel.load_pronn_model(fname)),
        ('pyrnn', lambda: TorchVGSLModel.load_pyrnn_model(fname)),
    )
    for parser_kind, loader in parsers:
        try:
            nn = loader()
        except Exception as exc:
            # A failing parser is expected for files of another format; keep
            # the reason available for debugging instead of discarding it.
            logger.debug('Parser %s could not load %s: %s', parser_kind, fname, exc)
            continue
        kind = parser_kind
        break

    if nn is None:
        raise KrakenInvalidModelException('File {} not loadable by any parser.'.format(fname))
    seq = TorchSeqRecognizer(nn, train=train, device=device)
    seq.kind = kind
    return seq
Ejemplo n.º 3
0
def segment(ctx, model, boxes, text_direction, scale, maxcolseps,
            black_colseps, remove_hlines, pad, mask):
    """
    Segments page images into text lines.

    Prepares the segmentation step: when the legacy box segmenter is not
    selected, the segmentation network (the given model or
    SEGMENTATION_DEFAULT_MODEL) is loaded and moved to the device stored in
    ctx.meta['device']. If loading fails a red cross is printed and
    ctx.exit(1) is called.

    Returns a partial of `segmenter` with all segmentation options and the
    device already bound.
    """
    # A model only makes sense for the baseline segmenter, so an explicitly
    # given model overrides the legacy box segmenter selection.
    if model and boxes:
        logger.warning(
            f'Baseline model ({model}) given but legacy segmenter selected. Forcing to -bl.'
        )
        boxes = False

    # Deliberately an equality test: only values comparing equal to False
    # (not e.g. None) trigger loading of the network.
    if boxes == False:
        model = model or SEGMENTATION_DEFAULT_MODEL
        # Imported lazily so the network code is only pulled in when needed.
        from kraken.lib.vgsl import TorchVGSLModel
        message(f'Loading ANN {model}\t', nl=False)
        try:
            model = TorchVGSLModel.load_model(model)
            model.to(ctx.meta['device'])
        except Exception:
            message('\u2717', fg='red')
            ctx.exit(1)
        message('\u2713', fg='green')

    bound_args = (
        boxes,
        model,
        text_direction,
        scale,
        maxcolseps,
        black_colseps,
        remove_hlines,
        pad,
        mask,
        ctx.meta['device'],
    )
    return partial(segmenter, *bound_args)
Ejemplo n.º 4
0
def compute_segmentation_map(im: PIL.Image.Image,
                             mask: Optional[PIL.Image.Image] = None,
                             model: vgsl.TorchVGSLModel = None,
                             device: str = 'cpu') -> Dict[str, Any]:
    """
    Runs a segmentation network on an image and returns its upsampled output
    together with the metadata needed to interpret it.

    Args:
        im: Input image
        mask: A bi-level mask of the same size as `im` where 0-valued
              regions are ignored for segmentation purposes. It is accessed
              through `.mode`, `.size` and `.convert()`, i.e. it has to be an
              image object rather than a plain array.
        model: A TorchVGSLModel containing a segmentation model.
        device: The target device to run the neural network on.

    Returns:
        A dictionary containing the network output upsampled to the scaled
        input size ('heatmap', numpy array), class map ('cls_map', taken from
        model.user_metadata['class_mapping']), the bounding regions for
        polygonization purposes ('bounding_regions', taken from
        model.user_metadata or None if absent), the scale between the input
        image and the network output ('scale', numpy array), and the scaled
        grayscale input image to the network ('scal_im', numpy array).

    Raises:
        KrakenInputException: When given an invalid mask.
    """
    im_str = get_im_str(im)
    logger.info(f'Segmenting {im_str}')

    if model.input[1] == 1 and model.one_channel_mode == '1' and not is_bitonal(im):
        logger.warning('Running binary model on non-binary input image '
                       '(mode {}). This will result in severely degraded '
                       'performance'.format(im.mode))

    model.eval()
    model.to(device)

    batch, channels, height, width = model.input
    transforms = dataset.ImageInputTransforms(batch,
                                              height,
                                              width,
                                              channels,
                                              0,
                                              valid_norm=False)
    # Apply only the transforms preceding the tensor conversion to obtain the
    # rescaled image itself; its shape is the upsampling target below.
    tf_idx, _ = next(
        filter(lambda x: isinstance(x[1], tf.ToTensor),
               enumerate(transforms.transforms)))
    res_tf = tf.Compose(transforms.transforms[:tf_idx])
    scal_im = np.array(res_tf(im).convert('L'))

    tensor_im = transforms(im)
    # Explicit None check: the mask is optional and its truthiness carries no
    # meaning.
    if mask is not None:
        if mask.mode != '1' and not is_bitonal(mask):
            logger.error('Mask is not bitonal')
            raise KrakenInputException('Mask is not bitonal')
        mask = mask.convert('1')
        if mask.size != im.size:
            size_msg = f"Mask size {mask.size} doesn't match image size {im.size}"
            logger.error(size_msg)
            raise KrakenInputException(size_msg)
        logger.info('Masking enabled in segmenter.')
        # Zero out every input position where the transformed mask is 0.
        tensor_im[~transforms(mask).bool()] = 0

    with torch.no_grad():
        logger.debug('Running network forward pass')
        o, _ = model.nn(tensor_im.unsqueeze(0).to(device))
    logger.debug('Upsampling network output')
    o = F.interpolate(o, size=scal_im.shape)
    o = o.squeeze().cpu().numpy()
    # o.shape[:0:-1] is every dimension except the first, in reverse order;
    # im.size is divided element-wise by it.
    scale = np.divide(im.size, o.shape[:0:-1])
    bounding_regions = model.user_metadata[
        'bounding_regions'] if 'bounding_regions' in model.user_metadata else None
    return {
        'heatmap': o,
        'cls_map': model.user_metadata['class_mapping'],
        'bounding_regions': bounding_regions,
        'scale': scale,
        'scal_im': scal_im
    }