def load_clstm(fname):
    """
    Wraps a CLSTM protobuf model in a ClstmSeqRecognizer.

    Args:
        fname (str): Path to the protobuf file

    Returns:
        The ClstmSeqRecognizer constructed from `fname`.

    Raises:
        KrakenInvalidModelException if the clstm module cannot be imported or
        if constructing the recognizer fails for any reason.
    """
    logger.info(u'Trying to load clstm model from {}'.format(fname))

    # Availability probe only: the imported name is not used in this function.
    try:
        import clstm  # noqa: F401
    except ImportError:
        logger.debug(u'No clstm module available')
        raise KrakenInvalidModelException('No clstm module available')

    try:
        recognizer = ClstmSeqRecognizer(fname)
    except Exception as err:
        # Any construction failure is reported as an invalid model.
        logger.debug(u'Loading clstm model failed.')
        raise KrakenInvalidModelException(str(err))
    else:
        return recognizer
def load_model(cls, path: str):
    """
    Deserializes a VGSL model from a CoreML file.

    The VGSL spec stored under the `vgsl` metadata key is used to build the
    network, each top-level layer then restores its own weights from the
    CoreML spec, and the optional `codec` and `kraken_meta` metadata entries
    are applied on top.

    Args:
        path (str): CoreML file

    Returns:
        A TorchVGSLModel instance.

    Raises:
        KrakenInvalidModelException if the model data is invalid (not a
        string, protobuf file, or without appropriate metadata). The
        triggering exception is attached as `__cause__`.
        FileNotFoundError if the path doesn't point to a file.
    """
    try:
        mlmodel = MLModel(path)
    except TypeError as e:
        raise KrakenInvalidModelException(str(e)) from e
    except DecodeError as e:
        raise KrakenInvalidModelException(
            'Failure parsing model protobuf: {}'.format(e)) from e

    metadata = mlmodel.user_defined_metadata
    if 'vgsl' not in metadata:
        raise KrakenInvalidModelException('No VGSL spec in model metadata')
    nn = cls(metadata['vgsl'])

    # Fetch the spec once instead of once per layer.
    # NOTE(review): assumes layer.deserialize() only reads from the spec.
    spec = mlmodel.get_spec()
    for name, layer in nn.nn.named_children():
        layer.deserialize(name, spec)

    if 'codec' in metadata:
        nn.add_codec(PytorchCodec(json.loads(metadata['codec'])))

    # Defaults; values are of mixed types (list, str, None, dict).
    nn.user_metadata = {'accuracy': [],
                        'seg_type': 'bbox',
                        'one_channel_mode': '1',
                        'model_type': None,
                        'hyper_params': {}}
    if 'kraken_meta' in metadata:
        nn.user_metadata.update(json.loads(metadata['kraken_meta']))
    return nn
def load_pronn_model(cls, path: str):
    """
    Loads a pronn (protobuf-serialized pyrnn) model to VGSL.

    The protobuf message is parsed, its weight arrays are converted to
    tensors and installed on a freshly built
    `[1,1,0,<ninput> Lbxo<hidden> O1ca<codec size>]` network.

    Args:
        path (str): Path to the protobuf file

    Returns:
        An instance of `cls` carrying the converted weights and codec.

    Raises:
        KrakenInvalidModelException if the file does not parse as a pyrnn
        protobuf message (original error attached as `__cause__`) or the
        message is incomplete.
    """
    with open(path, 'rb') as fp:
        net = pyrnn_pb2.pyrnn()
        try:
            net.ParseFromString(fp.read())
        except Exception as e:
            raise KrakenInvalidModelException(
                'File does not contain valid proto msg') from e
    if not net.IsInitialized():
        raise KrakenInvalidModelException('Model incomplete')

    def _to_tensor(array):
        # Proto arrays carry a flat value list plus their dimensions.
        return torch.Tensor(array.value).view(list(array.dim))

    # extract codec
    codec = PytorchCodec(net.codec)

    n_input = net.ninput
    hidden = net.fwdnet.wgi.dim[0]

    # extract weights
    weightnames = ('wgi', 'wgf', 'wci', 'wgo', 'wip', 'wfp', 'wop')
    fwd_w = []
    rev_w = []
    for wname in weightnames:
        fwd_w.append(_to_tensor(getattr(net.fwdnet, wname)))
        rev_w.append(_to_tensor(getattr(net.revnet, wname)))

    # The first four matrices are stacked; the leading n_input + 1 columns
    # become the input-hidden weights, the remainder the hidden-hidden ones.
    split = n_input + 1
    fwd_gates = torch.cat(fwd_w[:4])
    rev_gates = torch.cat(rev_w[:4])
    weight_lin = _to_tensor(net.softmax.w2)

    # build vgsl spec and set weights
    nn = cls('[1,1,0,{} Lbxo{} O1ca{}]'.format(n_input, hidden, len(net.codec)))
    lstm = nn.nn.L_0.layer
    lstm.weight_ih_l0 = torch.nn.Parameter(fwd_gates[:, :split])
    lstm.weight_hh_l0 = torch.nn.Parameter(fwd_gates[:, split:])
    lstm.weight_ih_l0_reverse = torch.nn.Parameter(rev_gates[:, :split])
    lstm.weight_hh_l0_reverse = torch.nn.Parameter(rev_gates[:, split:])
    lstm.weight_ip_l0 = torch.nn.Parameter(fwd_w[4])
    lstm.weight_fp_l0 = torch.nn.Parameter(fwd_w[5])
    lstm.weight_op_l0 = torch.nn.Parameter(fwd_w[6])
    lstm.weight_ip_l0_reverse = torch.nn.Parameter(rev_w[4])
    lstm.weight_fp_l0_reverse = torch.nn.Parameter(rev_w[5])
    lstm.weight_op_l0_reverse = torch.nn.Parameter(rev_w[6])
    nn.nn.O_1.lin.weight = torch.nn.Parameter(weight_lin)

    nn.add_codec(codec)
    return nn
def load_model(cls, path: Union[str, pathlib.Path]):
    """
    Deserializes a VGSL model from a CoreML file.

    The network is built from the VGSL spec stored under the `vgsl` metadata
    key. Weights are then restored by walking the layer tree: containers of
    type MultiParamParallel/MultiParamSequential are descended into, every
    other layer deserializes itself from the CoreML spec.

    Args:
        path: CoreML file

    Returns:
        A TorchVGSLModel instance.

    Raises:
        KrakenInvalidModelException if the model data is invalid (not a
        string, protobuf file, or without appropriate metadata). The
        triggering exception is attached as `__cause__`.
        FileNotFoundError if the path doesn't point to a file.
    """
    if isinstance(path, pathlib.Path):
        path = path.as_posix()
    try:
        mlmodel = MLModel(path)
    except TypeError as e:
        raise KrakenInvalidModelException(str(e)) from e
    except DecodeError as e:
        raise KrakenInvalidModelException(
            'Failure parsing model protobuf: {}'.format(e)) from e

    metadata = mlmodel.user_defined_metadata
    if 'vgsl' not in metadata:
        raise KrakenInvalidModelException('No VGSL spec in model metadata')
    nn = cls(metadata['vgsl'])

    # Fetch the spec once and share it across all layers rather than
    # requesting it again for every leaf.
    # NOTE(review): assumes layer.deserialize() only reads from the spec.
    spec = mlmodel.get_spec()
    # Exact type match is kept deliberately: subclasses of the containers are
    # treated as leaf layers, as before.
    container_types = (layers.MultiParamParallel, layers.MultiParamSequential)

    def _deserialize_layers(name, layer):
        logger.debug(f'Deserializing layer {name} with type {type(layer)}')
        if type(layer) in container_types:
            for child_name, child in layer.named_children():
                _deserialize_layers(child_name, child)
        else:
            layer.deserialize(name, spec)

    _deserialize_layers('', nn.nn)

    if 'codec' in metadata:
        nn.add_codec(PytorchCodec(json.loads(metadata['codec'])))

    # Defaults; values are of mixed types (list, str, None, dict).
    nn.user_metadata = {
        'accuracy': [],
        'seg_type': 'bbox',
        'one_channel_mode': '1',
        'model_type': None,
        'hyper_params': {},
    }
    if 'kraken_meta' in metadata:
        nn.user_metadata.update(json.loads(metadata['kraken_meta']))
    return nn
def load_pyrnn(fname):
    """
    Loads a legacy RNN from a pickle file.

    Files whose name ends in `.gz` are opened through gzip, everything else
    is read as-is.

    Args:
        fname (str): Path to the pickle object

    Returns:
        Unpickled object

    Raises:
        KrakenInvalidModelException on python 3, when unpickling fails, or
        the unpickled object is not a SeqRecognizer.
    """
    logger.info(u'Trying to load pyrnn model from {}'.format(fname))
    if not PY2:
        logger.error(u'Loading pickle models is not supported on python 3')
        raise KrakenInvalidModelException('Loading pickle models is not '
                                          'supported on python 3')
    import cPickle

    def find_global(mname, cname):
        # Redirect module paths used by old ocropus/pyrnn pickles to their
        # kraken equivalents; everything else must already be imported.
        aliases = {
            'lstm.lstm': kraken.lib.lstm,
            'ocrolib.lstm': kraken.lib.lstm,
            'ocrolib.lineest': kraken.lib.lineest,
        }
        if mname in aliases:
            return getattr(aliases[mname], cname)
        return getattr(sys.modules[mname], cname)

    of = io.open
    if fname.endswith('.gz'):
        of = gzip.open
    # SECURITY NOTE: unpickling executes code referenced by the file. Only
    # load pickle models from trusted sources.
    with io.BufferedReader(of(fname, 'rb')) as fp:
        unpickler = cPickle.Unpickler(fp)
        unpickler.find_global = find_global
        try:
            rnn = unpickler.load()
        except Exception as e:
            logger.error(u'Model file is not a pickle')
            raise KrakenInvalidModelException(str(e))
    if not isinstance(rnn, kraken.lib.lstm.SeqRecognizer):
        logger.error(u'Model file is {} instead of SeqRecognizer'.format(
            type(rnn).__name__))
        raise KrakenInvalidModelException('Pickle is {} instead of '
                                          'SeqRecognizer'.format(
                                              type(rnn).__name__))
    return rnn
def load_pronn(fname):
    """
    Loads a legacy pyrnn model in protobuf format and instantiates a
    kraken.lib.lstm.SeqRecognizer object.

    Args:
        fname (str): Path to the protobuf file

    Returns:
        A kraken.lib.lstm.SeqRecognizer object

    Raises:
        KrakenInvalidModelException if the file does not parse as a pyrnn
        protobuf message or the parsed message is incomplete.
    """
    logger.info(u'Trying to load prornn model from {}'.format(fname))
    with open(fname, 'rb') as fp:
        logger.debug(u'Initializing protobuf message')
        proto = pyrnn_pb2.pyrnn()
        try:
            proto.ParseFromString(fp.read())
        except Exception:
            # `Exception` rather than a bare except, so KeyboardInterrupt and
            # SystemExit are not turned into "invalid model" errors.
            logger.debug(u'File does not contain valid proto msg')
            raise KrakenInvalidModelException(
                'File does not contain valid proto msg')
    # The file is no longer needed once the message has been parsed.
    if not proto.IsInitialized():
        logger.debug(u'Message in file incomplete')
        raise KrakenInvalidModelException('Model incomplete')

    # extract codec
    logger.debug(u'Extracting codec')
    codec = kraken.lib.lstm.Codec().init(proto.codec)
    hiddensize = proto.fwdnet.wgi.dim[0]

    # next build a line estimator
    logger.debug(u'Add line estimator')
    lnorm = kraken.lib.lineest.CenterNormalizer(proto.ninput)
    network = kraken.lib.lstm.SeqRecognizer(
        lnorm.target_height,
        hiddensize,
        codec=codec,
        normalize=kraken.lib.lstm.normalize_nfkc)

    logger.debug(u'Setting weights on BIDILSTM')
    parallel, softmax = network.lstm.nets
    fwdnet, revnet = parallel.nets
    revnet = revnet.net
    # Network attributes are upper case, the protobuf fields lower case.
    for w in ('WGI', 'WGF', 'WGO', 'WCI', 'WIP', 'WFP', 'WOP'):
        fwd_ar = getattr(proto.fwdnet, w.lower())
        rev_ar = getattr(proto.revnet, w.lower())
        setattr(fwdnet, w, numpy.array(fwd_ar.value).reshape(fwd_ar.dim))
        setattr(revnet, w, numpy.array(rev_ar.value).reshape(rev_ar.dim))
    softmax.W2 = numpy.array(proto.softmax.w2.value).reshape(
        proto.softmax.w2.dim)
    return network
def load_any(fname: str,
             train: bool = False,
             device: str = 'cpu') -> TorchSeqRecognizer:
    """
    Loads anything that was, is, and will be a valid ocropus model and
    instantiates a TorchSeqRecognizer from the network in the file.

    The parsers are tried in a fixed order and the first one that succeeds
    wins:

        * VGSL models stored as CoreML files (kind `vgsl`)
        * protobuf models generated by clstm (kind `clstm`)
        * protobuf models containing converted python BIDILSTMs
          (kind `pronn`)

    The string naming the successful parser is stored in the `kind`
    attribute of the returned recognizer.

    Args:
        fname: Path to the model. User and environment variables are
               expanded and the path is made absolute.
        train: Enables gradient calculation and dropout layers in model.
        device: Target device

    Returns:
        A kraken.lib.models.TorchSeqRecognizer object.

    Raises:
        KrakenInvalidModelException: if the model is not loadable by any
                                     parser. The error of the last parser
                                     tried is attached as `__cause__`.
    """
    fname = abspath(expandvars(expanduser(fname)))
    logger.info('Loading model from {}'.format(fname))

    loaders = (('vgsl', 'load_model'),
               ('clstm', 'load_clstm_model'),
               ('pronn', 'load_pronn_model'))

    nn = None
    kind = ''
    last_error = None
    for candidate, loader_name in loaders:
        try:
            # Attribute lookup stays inside the try so that a missing loader
            # is treated like a failed one.
            nn = getattr(TorchVGSLModel, loader_name)(fname)
        except Exception as e:
            # Failures are expected while probing formats; record the reason
            # instead of discarding it.
            logger.debug('Loading %s as %s model failed: %s', fname, candidate, e)
            last_error = e
            continue
        kind = candidate
        break

    if nn is None:
        raise KrakenInvalidModelException(
            'File {} not loadable by any parser.'.format(fname)) from last_error

    seq = TorchSeqRecognizer(nn, train=train, device=device)
    seq.kind = kind
    return seq
def load_pyrnn(fname):
    """
    Unpickles a legacy RNN model.

    A path ending in `.gz` is read through gzip, any other path is read
    directly. Module names found in old pickles are remapped to their kraken
    counterparts while unpickling.

    Args:
        fname (unicode): Path to the pickle object

    Returns:
        Unpickled object

    Raises:
        KrakenInvalidModelException when running on python 3, when unpickling
        fails, or when the result is not a kraken.lib.lstm.SeqRecognizer.
    """
    if not PY2:
        raise KrakenInvalidModelException('Loading pickle models is not '
                                          'supported on python 3')

    import cPickle

    def find_global(mname, cname):
        # Legacy module path -> module that now provides the class.
        remapped = {
            'lstm.lstm': kraken.lib.lstm,
            'ocrolib.lstm': kraken.lib.lstm,
            'ocrolib.lineest': kraken.lib.lineest,
        }
        module = remapped[mname] if mname in remapped else sys.modules[mname]
        return getattr(module, cname)

    opener = gzip.open if fname.endswith(u'.gz') else io.open
    with io.BufferedReader(opener(fname, 'rb')) as stream:
        loader = cPickle.Unpickler(stream)
        loader.find_global = find_global
        try:
            model = loader.load()
        except Exception as err:
            raise KrakenInvalidModelException(str(err))
        if isinstance(model, kraken.lib.lstm.SeqRecognizer):
            return model
        raise KrakenInvalidModelException('Pickle is %s instead of '
                                          'SeqRecognizer' %
                                          type(model).__name__)
def load_clstm(fname):
    """
    Wraps a CLSTM protobuf model in a ClstmSeqRecognizer.

    Args:
        fname (str): Path to the protobuf file

    Returns:
        The ClstmSeqRecognizer constructed from `fname`.

    Raises:
        KrakenInvalidModelException if the clstm module cannot be imported or
        if constructing the recognizer fails for any reason.
    """
    # Availability probe only: the imported name is not used in this function.
    try:
        import clstm  # noqa: F401
    except ImportError:
        raise KrakenInvalidModelException('No clstm module available')

    try:
        recognizer = ClstmSeqRecognizer(fname)
    except Exception as err:
        raise KrakenInvalidModelException(str(err))
    else:
        return recognizer
def load_pronn(fname):
    """
    Loads a legacy pyrnn model in protobuf format and instantiates a
    kraken.lib.lstm.SeqRecognizer object.

    Args:
        fname (unicode): Path to the protobuf file

    Returns:
        A kraken.lib.lstm.SeqRecognizer object

    Raises:
        KrakenInvalidModelException if the file does not parse as a pyrnn
        protobuf message or the parsed message is incomplete.
    """
    with open(fname, 'rb') as fp:
        proto = pyrnn_pb2.pyrnn()
        try:
            proto.ParseFromString(fp.read())
        except Exception:
            # `Exception` rather than a bare except, so KeyboardInterrupt and
            # SystemExit are not turned into "invalid model" errors.
            raise KrakenInvalidModelException(
                'File does not contain valid proto msg')
    # The file is no longer needed once the message has been parsed.
    if not proto.IsInitialized():
        raise KrakenInvalidModelException('Model incomplete')

    # extract codec
    codec = kraken.lib.lstm.Codec().init(proto.codec)
    hiddensize = proto.fwdnet.wgi.dim[0]

    # next build a line estimator
    lnorm = kraken.lib.lineest.CenterNormalizer(proto.ninput)
    network = kraken.lib.lstm.SeqRecognizer(
        lnorm.target_height,
        hiddensize,
        codec=codec,
        normalize=kraken.lib.lstm.normalize_nfkc)

    parallel, softmax = network.lstm.nets
    fwdnet, revnet = parallel.nets
    revnet = revnet.net
    # Network attributes are upper case, the protobuf fields lower case.
    for w in ('WGI', 'WGF', 'WGO', 'WCI', 'WIP', 'WFP', 'WOP'):
        fwd_ar = getattr(proto.fwdnet, w.lower())
        rev_ar = getattr(proto.revnet, w.lower())
        setattr(fwdnet, w, numpy.array(fwd_ar.value).reshape(fwd_ar.dim))
        setattr(revnet, w, numpy.array(rev_ar.value).reshape(rev_ar.dim))
    softmax.W2 = numpy.array(proto.softmax.w2.value).reshape(
        proto.softmax.w2.dim)
    return network
def load_clstm_model(cls, path: str):
    """
    Loads a CLSTM model to VGSL.

    The CLSTM network protobuf is parsed, the softmax layer and the forward
    and reversed LSTM sub-networks are located, and their weights are
    installed on a freshly built
    `[1,1,0,<ninput> Lbxc<hidden> O1ca<codec size>]` network.

    Args:
        path (str): Path to the CLSTM protobuf file

    Returns:
        An instance of `cls` carrying the converted weights and codec.

    Raises:
        KrakenInvalidModelException if the file does not parse as a CLSTM
        network message (original error attached as `__cause__`), the
        message is incomplete, or an expected sub-layer is missing.
    """
    net = clstm_pb2.NetworkProto()
    with open(path, 'rb') as fp:
        try:
            net.ParseFromString(fp.read())
        except Exception as e:
            raise KrakenInvalidModelException(
                'File does not contain valid proto msg') from e
    if not net.IsInitialized():
        raise KrakenInvalidModelException('Model incomplete')

    def _first_sub(parent, predicate, description):
        # Return the first sub-network matching `predicate`, reporting a
        # missing layer as an invalid model rather than an IndexError.
        for sub in parent.sub:
            if predicate(sub):
                return sub
        raise KrakenInvalidModelException(
            'Model does not contain a {} layer'.format(description))

    n_input = net.ninput
    attrib = {a.key: a.value for a in net.attribute}
    # mainline clstm models carry more than one attribute
    mode = 'clstm' if len(attrib) > 1 else 'clstm_compat'

    # extract codec
    codec = PytorchCodec([''] + [chr(x) for x in net.codec[1:]])

    # separate layers
    parallel = _first_sub(net, lambda n: n.kind == 'Parallel', 'Parallel')
    rev = _first_sub(parallel, lambda n: n.kind == 'Reversed', 'Reversed')
    nets = {
        'softm': _first_sub(net, lambda n: n.kind == 'SoftmaxLayer',
                            'SoftmaxLayer'),
        'lstm1': _first_sub(parallel, lambda n: n.kind.startswith('NPLSTM'),
                            'NPLSTM'),
        'lstm2': rev.sub[0],
    }
    hidden = int(nets['lstm1'].attribute[0].value)

    # layer key -> {weight name -> tensor}
    weights = {
        key: {w.name: torch.Tensor(w.value).view(list(w.dim))
              for w in sub.weights}
        for key, sub in nets.items()
    }

    if mode == 'clstm_compat':
        weightnames = ('.WGI', '.WGF', '.WCI', '.WGO')
        weightname_softm = '.W'
    else:
        weightnames = ('WGI', 'WGF', 'WCI', 'WGO')
        weightname_softm = 'W1'

    # input hidden and hidden-hidden weights are in one matrix. also
    # CLSTM/ocropy likes 1-augmenting every other tensor so the ih weights
    # are input+1 in one dimension.
    split = n_input + 1
    fwd = torch.cat([weights['lstm1'][wn] for wn in weightnames])
    rev_t = torch.cat([weights['lstm2'][wn] for wn in weightnames])

    weight_lin = weights['softm'][weightname_softm]
    if mode == 'clstm_compat':
        # prepend a zero column to the compat softmax weights
        weight_lin = torch.cat(
            [torch.zeros(len(weight_lin), 1), weight_lin], 1)

    # build vgsl spec and set weights
    nn = cls('[1,1,0,{} Lbxc{} O1ca{}]'.format(n_input, hidden, len(net.codec)))
    lstm = nn.nn.L_0.layer
    lstm.weight_ih_l0 = torch.nn.Parameter(fwd[:, :split])
    lstm.weight_hh_l0 = torch.nn.Parameter(fwd[:, split:])
    lstm.weight_ih_l0_reverse = torch.nn.Parameter(rev_t[:, :split])
    lstm.weight_hh_l0_reverse = torch.nn.Parameter(rev_t[:, split:])
    nn.nn.O_1.lin.weight = torch.nn.Parameter(weight_lin)

    nn.add_codec(codec)
    return nn
def load_pyrnn_model(cls, path: str):
    """
    Converts a pickled legacy pyrnn SeqRecognizer into a VGSL model.

    The pickle (optionally gzip-compressed when the path ends in `.gz`) is
    loaded with legacy module names remapped to kraken modules, and its
    weights are copied onto a freshly built
    `[1,1,0,<Ni> Lbxo<hidden> O1ca<codec size>]` network.

    Args:
        path (str): Path to the pickle file

    Returns:
        An instance of `cls` carrying the converted weights and codec.

    Raises:
        KrakenInvalidModelException when running on python 3, when unpickling
        fails, or when the result is not a kraken.lib.lstm.SeqRecognizer.
    """
    if not PY2:
        raise KrakenInvalidModelException('Loading pickle models is not supported on python 3')

    import cPickle

    def find_global(mname, cname):
        # Legacy module path -> module that now provides the class.
        remapped = {
            'lstm.lstm': kraken.lib.lstm,
            'ocrolib.lstm': kraken.lib.lstm,
            'ocrolib.lineest': kraken.lib.lineest,
        }
        module = remapped[mname] if mname in remapped else sys.modules[mname]
        return getattr(module, cname)

    opener = gzip.open if path.endswith('.gz') else io.open
    with io.BufferedReader(opener(path, 'rb')) as stream:
        loader = cPickle.Unpickler(stream)
        loader.find_global = find_global
        try:
            net = loader.load()
        except Exception as err:
            raise KrakenInvalidModelException(str(err))
        if not isinstance(net, kraken.lib.lstm.SeqRecognizer):
            raise KrakenInvalidModelException('Pickle is %s instead of '
                                              'SeqRecognizer' % type(net).__name__)

    # extract codec
    codec = PytorchCodec({k: [v] for k, v in net.codec.char2code.items()})

    n_in = net.Ni
    parallel, softmax = net.lstm.nets
    fwdnet, revnet = parallel.nets
    revnet = revnet.net
    hidden = fwdnet.WGI.shape[0]

    # extract weights, forward and reverse direction side by side
    forward = []
    backward = []
    for attr in ('WGI', 'WGF', 'WCI', 'WGO', 'WIP', 'WFP', 'WOP'):
        forward.append(torch.Tensor(getattr(fwdnet, attr)))
        backward.append(torch.Tensor(getattr(revnet, attr)))

    # The first four matrices are stacked; columns up to n_in + 1 become the
    # input-hidden weights, the remaining columns the hidden-hidden weights.
    cut = n_in + 1
    stacked_fwd = torch.cat(forward[:4])
    ih_fwd, hh_fwd = stacked_fwd[:, :cut], stacked_fwd[:, cut:]
    stacked_rev = torch.cat(backward[:4])
    ih_rev, hh_rev = stacked_rev[:, :cut], stacked_rev[:, cut:]

    weight_lin = torch.Tensor(softmax.W2)

    # build vgsl spec and set weights
    nn = cls('[1,1,0,{} Lbxo{} O1ca{}]'.format(n_in, hidden, len(net.codec.code2char)))
    lstm = nn.nn.L_0.layer
    assignments = (
        ('weight_ih_l0', ih_fwd),
        ('weight_hh_l0', hh_fwd),
        ('weight_ih_l0_reverse', ih_rev),
        ('weight_hh_l0_reverse', hh_rev),
        ('weight_ip_l0', forward[4]),
        ('weight_fp_l0', forward[5]),
        ('weight_op_l0', forward[6]),
        ('weight_ip_l0_reverse', backward[4]),
        ('weight_fp_l0_reverse', backward[5]),
        ('weight_op_l0_reverse', backward[6]),
    )
    for attr_name, tensor in assignments:
        setattr(lstm, attr_name, torch.nn.Parameter(tensor))
    nn.nn.O_1.lin.weight = torch.nn.Parameter(weight_lin)

    nn.add_codec(codec)
    return nn
def segment(im: PIL.Image.Image,
            text_direction: str = 'horizontal-lr',
            mask: Optional[np.ndarray] = None,
            reading_order_fn: Callable = polygonal_reading_order,
            model: Optional[Union[List[vgsl.TorchVGSLModel], vgsl.TorchVGSLModel]] = None,
            device: str = 'cpu') -> Dict[str, Any]:
    r"""
    Segments a page into text lines using the baseline segmenter.

    Segments a page into text lines and returns the polyline formed by each
    baseline and their estimated environment.

    Args:
        im: Input image. The mode can generally be anything but it is possible
            to supply a binarized-input-only model which requires accordingly
            treated images.
        text_direction: Passed-through value for serialization.serialize.
        mask: A bi-level mask image of the same size as `im` where 0-valued
              regions are ignored for segmentation purposes. Disables column
              detection.
        reading_order_fn: Function to determine the reading order. Has to
                          accept a list of tuples (baselines, polygon) and a
                          text direction (`lr` or `rl`).
        model: One or more TorchVGSLModel containing a segmentation model. If
               none is given a default model will be loaded. When several
               models are given each one is run in turn, but only the output
               of the last one is returned.
        device: The target device to run the neural network on.

    Returns:
        A dictionary containing the text direction and under the key 'lines' a
        list of reading order sorted baselines (polylines) and their respective
        polygonal boundaries. The last and first point of each boundary polygon
        are connected.

        .. code-block::
           :force:

            {'text_direction': '$dir',
             'type': 'baselines',
             'lines': [
                {'baseline': [[x0, y0], [x1, y1], ..., [x_n, y_n]], 'boundary': [[x0, y0, x1, y1], ... [x_m, y_m]]},
                {'baseline': [[x0, ...]], 'boundary': [[x0, ...]]}
              ],
             'regions': [
                {'region': [[x0, y0], [x1, y1], ..., [x_n, y_n]], 'type': 'image'},
                {'region': [[x0, ...]], 'type': 'text'}
              ],
             'script_detection': True
            }

    Raises:
        KrakenInvalidModelException: if the given model is not a valid
                                     segmentation model or an empty list of
                                     models is given.
        KrakenInputException: if the mask is not bitonal or does not match the
                              image size.
    """
    if model is None:
        logger.info('No segmentation model given. Loading default model.')
        model = vgsl.TorchVGSLModel.load_model(
            pkg_resources.resource_filename(__name__, 'blla.mlmodel'))

    if isinstance(model, vgsl.TorchVGSLModel):
        model = [model]

    # Without at least one model the loop below never assigns its results.
    if not model:
        raise KrakenInvalidModelException('No segmentation model given.')

    for nn in model:
        if nn.model_type != 'segmentation':
            raise KrakenInvalidModelException(
                f'Invalid model type {nn.model_type} for {nn}')
        if 'class_mapping' not in nn.user_metadata:
            raise KrakenInvalidModelException(
                f'Segmentation model {nn} does not contain valid class mapping')

    im_str = get_im_str(im)
    logger.info(f'Segmenting {im_str}')

    for net in model:
        has_topline = 'topline' in net.user_metadata
        topline = net.user_metadata['topline'] if has_topline else False
        if has_topline:
            loc = {None: 'center', True: 'top', False: 'bottom'}[topline]
            logger.debug(f'Baseline location: {loc}')

        rets = compute_segmentation_map(im, mask, net, device)
        regions = vec_regions(**rets)

        # flatten regions for line ordering/fetch bounding regions
        bounding = rets['bounding_regions']
        line_regs = []
        suppl_obj = []
        for cls, regs in regions.items():
            line_regs.extend(regs)
            if bounding is not None and cls in bounding:
                suppl_obj.extend(regs)

        # convert back to net scale
        suppl_obj = scale_regions(suppl_obj, 1 / rets['scale'])
        line_regs = scale_regions(line_regs, 1 / rets['scale'])

        lines = vec_lines(**rets,
                          regions=line_regs,
                          reading_order_fn=reading_order_fn,
                          text_direction=text_direction,
                          suppl_obj=suppl_obj,
                          topline=topline)

    # More than one baseline class in the class map enables script detection.
    script_detection = len(rets['cls_map']['baselines']) > 1

    return {'text_direction': text_direction,
            'type': 'baselines',
            'lines': lines,
            'regions': regions,
            'script_detection': script_detection}