def vec_regions(heatmap: torch.Tensor, cls_map: Dict, scale: float, **kwargs) -> Dict[str, List[List[Tuple[int, int]]]]:
    """
    Computes regions from a stack of heatmaps, a class mapping, and a scaling
    factor.

    Args:
        heatmap: A stack of heatmaps of shape `NxHxW` output from the network.
        cls_map: Dictionary mapping string identifiers to indices on the
            stack of heatmaps.
        scale: Scaling factor between heatmap and unscaled input image.

    Returns:
        A dictionary containing a key for each region type with a list of
        regions inside.
    """
    logger.info('Vectorizing regions')
    regions = {}
    for region_type, idx in cls_map['regions'].items():
        logger.debug(f'Vectorizing regions of type {region_type}')
        regions[region_type] = vectorize_regions(heatmap[idx])
    for reg_id, regs in regions.items():
        regions[reg_id] = scale_regions(regs, scale)
    return regions
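# A minimal usage sketch (illustrative only, not part of the module): the
# keyword layout of `vec_regions` matches the dictionary returned by
# `compute_segmentation_map`, which `segment` below calls as
# `compute_segmentation_map(im, mask, net, device)`, so the two can be
# chained directly. The file name is a hypothetical placeholder and `net`
# is assumed to be a loaded segmentation TorchVGSLModel.
#
#     from PIL import Image
#
#     im = Image.open('page.png')
#     rets = compute_segmentation_map(im, None, net, 'cpu')
#     regions = vec_regions(**rets)
#     for region_type, regs in regions.items():
#         print(region_type, len(regs))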
def segment(im: PIL.Image.Image,
            text_direction: str = 'horizontal-lr',
            mask: Optional[np.ndarray] = None,
            reading_order_fn: Callable = polygonal_reading_order,
            model: Union[List[vgsl.TorchVGSLModel], vgsl.TorchVGSLModel] = None,
            device: str = 'cpu') -> Dict[str, Any]:
    r"""
    Segments a page into text lines using the baseline segmenter.

    Segments a page into text lines and returns the polyline formed by each
    baseline and their estimated environment.

    Args:
        im: Input image. The mode can generally be anything but it is
            possible to supply a binarized-input-only model which requires
            accordingly treated images.
        text_direction: Passed-through value for serialization.serialize.
        mask: A bi-level mask image of the same size as `im` where 0-valued
            regions are ignored for segmentation purposes. Disables column
            detection.
        reading_order_fn: Function to determine the reading order. Has to
            accept a list of tuples (baselines, polygon) and a text
            direction (`lr` or `rl`).
        model: One or more TorchVGSLModel containing a segmentation model.
            If none is given a default model will be loaded.
        device: The target device to run the neural network on.

    Returns:
        A dictionary containing the text direction and under the key 'lines'
        a list of reading order sorted baselines (polylines) and their
        respective polygonal boundaries. The last and first point of each
        boundary polygon are connected.

        .. code-block::
           :force:

            {'text_direction': '$dir',
             'type': 'baselines',
             'lines': [
                 {'baseline': [[x0, y0], [x1, y1], ..., [x_n, y_n]],
                  'boundary': [[x0, y0], [x1, y1], ..., [x_m, y_m]]},
                 {'baseline': [[x0, ...]], 'boundary': [[x0, ...]]}
             ],
             'regions': {'$region_type': [[[x0, y0], [x1, y1], ..., [x_n, y_n]],
                                          ...],
                         ...},
             'script_detection': True}

    Raises:
        KrakenInvalidModelException: if the given model is not a valid
            segmentation model.
        KrakenInputException: if the mask is not bitonal or does not match
            the image size.
    """
    if model is None:
        logger.info('No segmentation model given. Loading default model.')
        model = vgsl.TorchVGSLModel.load_model(
            pkg_resources.resource_filename(__name__, 'blla.mlmodel'))

    if isinstance(model, vgsl.TorchVGSLModel):
        model = [model]

    for nn in model:
        if nn.model_type != 'segmentation':
            raise KrakenInvalidModelException(
                f'Invalid model type {nn.model_type} for {nn}')
        if 'class_mapping' not in nn.user_metadata:
            raise KrakenInvalidModelException(
                f'Segmentation model {nn} does not contain valid class mapping')

    im_str = get_im_str(im)
    logger.info(f'Segmenting {im_str}')

    for net in model:
        if 'topline' in net.user_metadata:
            loc = {None: 'center',
                   True: 'top',
                   False: 'bottom'}[net.user_metadata['topline']]
            logger.debug(f'Baseline location: {loc}')
        rets = compute_segmentation_map(im, mask, net, device)
        regions = vec_regions(**rets)
        # flatten regions for line ordering/fetch bounding regions
        line_regs = []
        suppl_obj = []
        for cls, regs in regions.items():
            line_regs.extend(regs)
            if rets['bounding_regions'] is not None and cls in rets['bounding_regions']:
                suppl_obj.extend(regs)
        # convert back to net scale
        suppl_obj = scale_regions(suppl_obj, 1/rets['scale'])
        line_regs = scale_regions(line_regs, 1/rets['scale'])

        lines = vec_lines(**rets,
                          regions=line_regs,
                          reading_order_fn=reading_order_fn,
                          text_direction=text_direction,
                          suppl_obj=suppl_obj,
                          topline=net.user_metadata.get('topline', False))

    if len(rets['cls_map']['baselines']) > 1:
        script_detection = True
    else:
        script_detection = False

    return {'text_direction': text_direction,
            'type': 'baselines',
            'lines': lines,
            'regions': regions,
            'script_detection': script_detection}
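# Example invocation (a hedged sketch; relies on the packaged default model
# that `segment` loads when `model` is None, and on a hypothetical input
# file). Per the docstring above, each entry in `res['lines']` carries a
# 'baseline' polyline and a 'boundary' polygon, and `res['regions']` maps
# region types to lists of region polygons.
#
#     from PIL import Image
#
#     im = Image.open('page.png')
#     res = segment(im, text_direction='horizontal-lr', device='cpu')
#     for line in res['lines']:
#         print(line['baseline'])           # [[x0, y0], [x1, y1], ...]
#     for rtype, regs in res['regions'].items():
#         print(rtype, len(regs))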