def vec_regions(heatmap: torch.Tensor, cls_map: Dict, scale: float, **kwargs):
    """
    Vectorizes and rescales the regions found in a stack of heatmaps.

    Args:
        heatmap: Stack of heatmaps; it is indexed with the values found in
                 `cls_map['regions']` to select one map per region type.
        cls_map: Class mapping. Only the `'regions'` entry is read; it maps
                 region type identifiers to indices into `heatmap`.
        scale: Scaling factor handed to `scale_regions` for every region
               type.
        **kwargs: Accepted and ignored.

    Returns:
        A dictionary with one key per region type whose value is the output
        of `scale_regions` for the regions vectorized from that type's
        heatmap.
    """
    logger.info('Vectorizing regions')

    # Pass 1: vectorize every region type in heatmap coordinates.
    unscaled = {}
    for region_type, channel in cls_map['regions'].items():
        logger.debug(f'Vectorizing regions of type {region_type}')
        unscaled[region_type] = vectorize_regions(heatmap[channel])

    # Pass 2: rescale only after all types have been vectorized.
    return {region_type: scale_regions(found, scale)
            for region_type, found in unscaled.items()}
def vec_regions(heatmap: torch.Tensor, cls_map: Dict, scale: float, **kwargs) -> Dict[str, List[List[Tuple[int, int]]]]:
    """
    Vectorizes and rescales the regions found in a stack of heatmaps.

    Args:
        heatmap: A stack of heatmaps of shape `NxHxW` output from the network.
        cls_map: Dictionary mapping string identifiers to indices on the stack
                 of heatmaps. Only its `'regions'` entry is read here.
        scale: Scaling factor between heatmap and unscaled input image.
        **kwargs: Accepted and ignored.

    Returns:
        A dictionary containing a key for each region type with a list of
        scaled regions inside.
    """
    logger.info('Vectorizing regions')

    # Pass 1: vectorize every region type in heatmap coordinates.
    unscaled = {}
    for region_type, channel in cls_map['regions'].items():
        logger.debug(f'Vectorizing regions of type {region_type}')
        unscaled[region_type] = vectorize_regions(heatmap[channel])

    # Pass 2: rescale only after all types have been vectorized.
    return {region_type: scale_regions(found, scale)
            for region_type, found in unscaled.items()}
def segment(im,
            text_direction: str = 'horizontal-lr',
            mask: Optional[np.array] = None,
            reading_order_fn: Callable = polygonal_reading_order,
            model=None,
            device: str = 'cpu'):
    """
    Segments a page into text lines using the baseline segmenter.

    Segments a page into text lines and returns the polyline formed by each
    baseline and their estimated environment.

    Args:
        im (PIL.Image): An RGB image.
        text_direction (str): Principal text direction. Its last two
                              characters (e.g. `lr`) are passed to
                              `reading_order_fn`; the full value is copied
                              into the result for serialization.
        mask (PIL.Image): A bi-level mask image of the same size as `im`.
                          NOTE(review): the mask is validated and converted
                          to an array here but is not consulted by any later
                          step of this function - confirm intended use.
        reading_order_fn (function): Function to determine the reading order.
                                     It is called with the keyword arguments
                                     `lines` (list of `(type, baseline,
                                     boundary)` tuples), `regions` (flat list
                                     of all scaled regions) and
                                     `text_direction` (`lr` or `rl`).
        model (vgsl.TorchVGSLModel): A TorchVGSLModel containing a
                                     segmentation model. If none is given a
                                     default model will be loaded.
        device (str or torch.Device): The target device to run the neural
                                      network on.

    Returns:
        {'text_direction': '$dir',
         'type': 'baselines',
         'lines': [
            {'script': '$type',
             'baseline': [[x0, y0], [x1, y1], ..., [x_n, y_n]],
             'boundary': [[x0, y0], [x1, y1], ..., [x_m, y_m]]},
            ...
         ],
         'regions': {
            '$region_type': [[[x0, y0], [x1, y1], ..., [x_n, y_n]], ...],
            ...
         },
         'script_detection': True or False
        }: A dictionary containing the text direction and under the key
        'lines' a list of reading order sorted baselines (polylines) and their
        respective polygonal boundaries. 'regions' maps each region type of
        the model's class mapping to its scaled regions. 'script_detection'
        is True when the class mapping defines more than one baseline type.

    Raises:
        KrakenInputException if a mask is given that is not bitonal or whose
        size differs from the size of the input image.
    """
    im_str = get_im_str(im)
    logger.info(f'Segmenting {im_str}')

    if model is None:
        logger.info('No segmentation model given. Loading default model.')
        model = vgsl.TorchVGSLModel.load_model(pkg_resources.resource_filename(__name__, 'blla.mlmodel'))

    if model.one_channel_mode == '1' and not is_bitonal(im):
        logger.warning(f'Running binary model on non-binary input image '
                       f'(mode {im.mode}). This will result in severely degraded '
                       f'performance')

    model.eval()
    model.to(device)

    # Explicit None test: truthiness of array-like masks is ambiguous and
    # would raise instead of enabling masking.
    if mask is not None:
        if mask.mode != '1' and not is_bitonal(mask):
            logger.error('Mask is not bitonal')
            raise KrakenInputException('Mask is not bitonal')
        mask = mask.convert('1')
        if mask.size != im.size:
            # Previously emitted with literal `{mask.size}` placeholders
            # because the `f` prefix was missing.
            msg = f"Mask size {mask.size} doesn't match image size {im.size}"
            logger.error(msg)
            raise KrakenInputException(msg)
        logger.info('Masking enabled in segmenter.')
        mask = pil2array(mask)

    batch, channels, height, width = model.input
    transforms = dataset.generate_input_transforms(batch, height, width, channels, 0, valid_norm=False)
    # Apply only the first three transforms to obtain the rescaled image the
    # network output is aligned with.
    res_tf = tf.Compose(transforms.transforms[:3])
    scal_im = res_tf(im).convert('L')

    with torch.no_grad():
        logger.debug('Running network forward pass')
        o = model.nn(transforms(im).unsqueeze(0).to(device))
        logger.debug('Upsampling network output')
        # PIL sizes are (width, height); interpolate expects (height, width).
        o = F.interpolate(o, size=scal_im.size[::-1])
    o = o.squeeze().cpu().numpy()

    # Per-axis factor between the original image and the heatmaps.
    # `o.shape[:0:-1]` drops the first axis and reverses the rest, i.e.
    # (W, H) - assumes `o` is a (C, H, W) stack after squeezing.
    scale = np.divide(im.size, o.shape[:0:-1])

    # postprocessing
    cls_map = model.user_metadata['class_mapping']
    st_sep = cls_map['aux']['_start_separator']
    end_sep = cls_map['aux']['_end_separator']

    logger.info('Vectorizing baselines')
    baselines = []
    for bl_type, idx in cls_map['baselines'].items():
        logger.debug(f'Vectorizing lines of type {bl_type}')
        # Each baseline class is vectorized together with the shared
        # start/end separator maps.
        baselines.extend((bl_type, bl) for bl in vectorize_lines(o[(st_sep, end_sep, idx), :, :]))

    logger.info('Vectorizing regions')
    regions = {}
    for region_type, idx in cls_map['regions'].items():
        # Log the region type; the old message referenced `bl_type`, which
        # was misleading and unbound when no baseline classes exist.
        logger.debug(f'Vectorizing regions of type {region_type}')
        regions[region_type] = vectorize_regions(o[idx])

    logger.debug('Polygonizing lines')
    bl_types = [x[0] for x in baselines]
    bl_coords = [x[1] for x in baselines]
    polygons = calculate_polygonal_environment(scal_im, bl_coords)
    # Drop lines for which no polygonal environment was returned.
    lines = [(bl_type, bl, pl)
             for bl_type, bl, pl in zip(bl_types, bl_coords, polygons)
             if pl is not None]

    logger.debug('Scaling vectorized lines')
    sc = scale_polygonal_lines([x[1:] for x in lines], scale)
    lines = [(line[0], scaled[0], scaled[1]) for line, scaled in zip(lines, sc)]

    logger.debug('Scaling vectorized regions')
    for reg_id, regs in regions.items():
        regions[reg_id] = scale_regions(regs, scale)

    logger.debug('Reordering baselines')
    # The reading order function receives all regions as one flat list.
    order_regs = [reg for regs in regions.values() for reg in regs]
    lines = reading_order_fn(lines=lines, regions=order_regs, text_direction=text_direction[-2:])

    # `class_mapping` was already subscripted above, so its presence is
    # guaranteed at this point.
    script_detection = len(cls_map['baselines']) > 1

    return {'text_direction': text_direction,
            'type': 'baselines',
            'lines': [{'script': bl_type, 'baseline': bl, 'boundary': pl} for bl_type, bl, pl in lines],
            'regions': regions,
            'script_detection': script_detection}