示例#1
0
def vec_regions(heatmap: torch.Tensor, cls_map: Dict, scale: float, **kwargs):
    """
    Vectorizes and rescales the regions of every region class.

    Each entry of `cls_map['regions']` maps a region type to an index into
    `heatmap`. The indexed heatmap is handed to `vectorize_regions` and, once
    all types have been vectorized, every result is passed through
    `scale_regions` together with `scale`.

    Args:
        heatmap: Stack of heatmaps, indexed by the values of
                 `cls_map['regions']`.
        cls_map: Class mapping; only the `'regions'` entry is read.
        scale: Scaling factor forwarded to `scale_regions`.
        **kwargs: Accepted and ignored.

    Returns:
        A dictionary with one key per region type holding the scaled regions
        of that type.
    """
    logger.info('Vectorizing regions')
    # first pass: vectorize every region class at heatmap resolution
    vectorized = {}
    for region_type, idx in cls_map['regions'].items():
        logger.debug(f'Vectorizing regions of type {region_type}')
        vectorized[region_type] = vectorize_regions(heatmap[idx])
    # second pass: rescale the vectorized regions
    return {region_type: scale_regions(polygons, scale)
            for region_type, polygons in vectorized.items()}
示例#2
0
文件: blla.py 项目: sixtyfive/kraken
def vec_regions(heatmap: torch.Tensor, cls_map: Dict, scale: float,
                **kwargs) -> Dict[str, List[List[Tuple[int, int]]]]:
    """
    Extracts region polygons per region class and rescales them.

    Args:
        heatmap: A stack of heatmaps of shape `NxHxW` output from the network.
        cls_map: Dictionary mapping string identifiers to indices on the stack
                 of heatmaps. Only the `'regions'` entry is read here.
        scale: Scaling factor between heatmap and unscaled input image; it is
               forwarded to `scale_regions`.
        **kwargs: Accepted and ignored.

    Returns:
        A dictionary containing a key for each region type with a list of
        scaled regions inside.
    """
    logger.info('Vectorizing regions')
    region_classes = cls_map['regions']

    # Phase 1: vectorize each class' heatmap.
    raw_regions = {}
    for region_type, channel in region_classes.items():
        logger.debug(f'Vectorizing regions of type {region_type}')
        raw_regions[region_type] = vectorize_regions(heatmap[channel])

    # Phase 2: rescale, keeping the insertion order of the classes.
    scaled_regions = {}
    for region_type, polygons in raw_regions.items():
        scaled_regions[region_type] = scale_regions(polygons, scale)
    return scaled_regions
示例#3
0
def segment(im,
            text_direction: str = 'horizontal-lr',
            mask: Optional[np.array] = None,
            reading_order_fn: Callable = polygonal_reading_order,
            model=None,
            device: str = 'cpu'):
    """
    Segments a page into text lines using the baseline segmenter.

    Segments a page into text lines and returns the polyline formed by each
    baseline and their estimated environment, together with the regions
    found on the page.

    Args:
        im (PIL.Image): Input image. A warning is logged when a model in
                        binary one-channel mode is run on a non-bitonal
                        image.
        text_direction (str): Returned unchanged in the output; only its last
                              two characters (e.g. `lr` or `rl`) are passed
                              to `reading_order_fn`.
        mask (PIL.Image): A bi-level mask image of the same size as `im`. It
                          is validated (bitonal, same size) and converted to
                          an array.
        reading_order_fn (function): Function to determine the reading order.
                                     Called with the keyword arguments
                                     `lines`, `regions` and `text_direction`.
        model (vgsl.TorchVGSLModel): A TorchVGSLModel containing a segmentation
                                     model. If none is given a default model
                                     will be loaded.
        device (str or torch.Device): The target device to run the neural
                                      network on.

    Returns:
        A dictionary of the form

        {'text_direction': '$dir',
         'type': 'baselines',
         'lines': [
            {'script': '$type', 'baseline': [...], 'boundary': [...]},
            ...
          ],
          'regions': {'$region_type': [region, ...], ...},
          'script_detection': True/False
        }

        where 'lines' holds the baselines in the order returned by
        `reading_order_fn` with their polygonal boundaries, 'regions' maps
        each region type of the model's class mapping to its scaled regions,
        and 'script_detection' is True when the class mapping defines more
        than one baseline type.

    Raises:
        KrakenInputException if the mask is not bitonal or its size does not
        match the size of the input image.
    """
    im_str = get_im_str(im)
    logger.info(f'Segmenting {im_str}')

    if model is None:
        logger.info('No segmentation model given. Loading default model.')
        model = vgsl.TorchVGSLModel.load_model(pkg_resources.resource_filename(__name__, 'blla.mlmodel'))

    if model.one_channel_mode == '1' and not is_bitonal(im):
        logger.warning(f'Running binary model on non-binary input image '
                       f'(mode {im.mode}). This will result in severely degraded '
                       f'performance')

    model.eval()
    model.to(device)

    # explicit None check: do not rely on the truthiness of the mask object
    if mask is not None:
        if mask.mode != '1' and not is_bitonal(mask):
            logger.error('Mask is not bitonal')
            raise KrakenInputException('Mask is not bitonal')
        mask = mask.convert('1')
        if mask.size != im.size:
            msg = f'Mask size {mask.size} doesn\'t match image size {im.size}'
            logger.error(msg)
            raise KrakenInputException(msg)
        logger.info('Masking enabled in segmenter.')
        # NOTE(review): the converted mask is not referenced again in this
        # function - confirm whether it should be applied to the network
        # output.
        mask = pil2array(mask)

    batch, channels, height, width = model.input
    transforms = dataset.generate_input_transforms(batch, height, width, channels, 0, valid_norm=False)
    # only the first three transforms are applied to obtain the image used
    # for upsampling and polygonization
    res_tf = tf.Compose(transforms.transforms[:3])
    scal_im = res_tf(im).convert('L')

    with torch.no_grad():
        logger.debug('Running network forward pass')
        o = model.nn(transforms(im).unsqueeze(0).to(device))
    logger.debug('Upsampling network output')
    # the reversed size of `scal_im` is the target size of the upsampling
    o = F.interpolate(o, size=scal_im.size[::-1])
    o = o.squeeze().cpu().numpy()
    # all output dimensions except the first, reversed, to match `im.size`
    scale = np.divide(im.size, o.shape[:0:-1])

    # postprocessing
    cls_map = model.user_metadata['class_mapping']
    st_sep = cls_map['aux']['_start_separator']
    end_sep = cls_map['aux']['_end_separator']

    logger.info('Vectorizing baselines')
    baselines = []
    for bl_type, idx in cls_map['baselines'].items():
        logger.debug(f'Vectorizing lines of type {bl_type}')
        # start separator, end separator and baseline channels, in that order
        baselines.extend((bl_type, bl) for bl in vectorize_lines(o[(st_sep, end_sep, idx), :, :]))

    logger.info('Vectorizing regions')
    regions = {}
    for region_type, idx in cls_map['regions'].items():
        logger.debug(f'Vectorizing regions of type {region_type}')
        regions[region_type] = vectorize_regions(o[idx])

    logger.debug('Polygonizing lines')
    bl_types = [bl[0] for bl in baselines]
    bl_coords = [bl[1] for bl in baselines]
    polygons = calculate_polygonal_environment(scal_im, bl_coords)
    # drop lines without a polygon (None entries)
    lines = [line for line in zip(bl_types, bl_coords, polygons) if line[2] is not None]

    logger.debug('Scaling vectorized lines')
    scaled = scale_polygonal_lines([line[1:] for line in lines], scale)
    lines = [(line[0], sc[0], sc[1]) for line, sc in zip(lines, scaled)]

    logger.debug('Scaling vectorized regions')
    regions = {reg_id: scale_regions(regs, scale) for reg_id, regs in regions.items()}

    logger.debug('Reordering baselines')
    # flatten regions of all types into a single list for line ordering
    order_regs = [reg for regs in regions.values() for reg in regs]
    lines = reading_order_fn(lines=lines, regions=order_regs, text_direction=text_direction[-2:])

    # the class mapping has already been read from the model metadata above,
    # so no further presence check is needed here
    script_detection = len(cls_map['baselines']) > 1

    return {'text_direction': text_direction,
            'type': 'baselines',
            'lines': [{'script': bl_type, 'baseline': bl, 'boundary': pl} for bl_type, bl, pl in lines],
            'regions': regions,
            'script_detection': script_detection}
示例#4
0
文件: blla.py 项目: sixtyfive/kraken
def segment(im: PIL.Image.Image,
            text_direction: str = 'horizontal-lr',
            mask: Optional[np.ndarray] = None,
            reading_order_fn: Callable = polygonal_reading_order,
            model: Union[List[vgsl.TorchVGSLModel],
                         vgsl.TorchVGSLModel] = None,
            device: str = 'cpu') -> Dict[str, Any]:
    r"""
    Segments a page into text lines using the baseline segmenter.

    Segments a page into text lines and returns the polyline formed by each
    baseline and their estimated environment.

    Args:
        im: Input image. The mode can generally be anything but it is possible
            to supply a binarized-input-only model which requires accordingly
            treated images.
        text_direction: Passed-through value for serialization.serialize. It
                        is also forwarded to `vec_lines`.
        mask: A bi-level mask image of the same size as `im` where 0-valued
              regions are ignored for segmentation purposes. Disables column
              detection. Forwarded to `compute_segmentation_map`.
        reading_order_fn: Function to determine the reading order.  Has to
                          accept a list of tuples (baselines, polygon) and a
                          text direction (`lr` or `rl`).
        model: One or more TorchVGSLModel containing a segmentation model. If
               none is given a default model will be loaded.
        device: The target device to run the neural network on.

    Returns:
        A dictionary containing the text direction and under the key 'lines'
        the value returned by `vec_lines`, under the key 'regions' the value
        returned by `vec_regions`, and under 'script_detection' whether the
        class mapping defines more than one baseline type.

        .. code-block::
           :force:

            {'text_direction': '$dir',
             'type': 'baselines',
             'lines': <output of vec_lines>,
             'regions': <output of vec_regions>,
             'script_detection': True/False
            }

    Raises:
        KrakenInvalidModelException: if no model is given in the list, or a
                                     given model is not a valid segmentation
                                     model.
        KrakenInputException: if the mask is not bitonal or does not match the
                              image size.
    """
    if model is None:
        logger.info('No segmentation model given. Loading default model.')
        model = vgsl.TorchVGSLModel.load_model(
            pkg_resources.resource_filename(__name__, 'blla.mlmodel'))

    if isinstance(model, vgsl.TorchVGSLModel):
        model = [model]

    # an empty collection would skip both loops below and fail later with an
    # unrelated error on the unbound segmentation result
    if not model:
        raise KrakenInvalidModelException('No segmentation model given.')

    # validate all models before running any of them
    for net in model:
        if net.model_type != 'segmentation':
            raise KrakenInvalidModelException(
                f'Invalid model type {net.model_type} for {net}')
        if 'class_mapping' not in net.user_metadata:
            raise KrakenInvalidModelException(
                f'Segmentation model {net} does not contain valid class mapping'
            )

    im_str = get_im_str(im)
    logger.info(f'Segmenting {im_str}')

    # NOTE(review): `regions`, `lines` and `rets` are reassigned on every
    # iteration, so only the output of the last model is returned - confirm
    # whether results of multiple models are meant to be merged.
    for net in model:
        has_topline = 'topline' in net.user_metadata
        if has_topline:
            loc = {
                None: 'center',
                True: 'top',
                False: 'bottom'
            }[net.user_metadata['topline']]
            logger.debug(f'Baseline location: {loc}')
        rets = compute_segmentation_map(im, mask, net, device)
        regions = vec_regions(**rets)
        # flatten regions for line ordering/fetch bounding regions
        bounding_regions = rets['bounding_regions']
        line_regs = []
        suppl_obj = []
        for region_type, regs in regions.items():
            line_regs.extend(regs)
            if bounding_regions is not None and region_type in bounding_regions:
                suppl_obj.extend(regs)
        # convert back to net scale
        inv_scale = 1 / rets['scale']
        suppl_obj = scale_regions(suppl_obj, inv_scale)
        line_regs = scale_regions(line_regs, inv_scale)
        # models without a 'topline' entry fall back to False
        topline = net.user_metadata['topline'] if has_topline else False
        lines = vec_lines(**rets,
                          regions=line_regs,
                          reading_order_fn=reading_order_fn,
                          text_direction=text_direction,
                          suppl_obj=suppl_obj,
                          topline=topline)

    script_detection = len(rets['cls_map']['baselines']) > 1

    return {
        'text_direction': text_direction,
        'type': 'baselines',
        'lines': lines,
        'regions': regions,
        'script_detection': script_detection
    }